Source code for marshmallow_sqlalchemy.schema

from __future__ import annotations

import inspect
from typing import TYPE_CHECKING, Any, cast

import sqlalchemy as sa
from marshmallow.fields import Field
from marshmallow.schema import Schema, SchemaMeta, SchemaOpts, _get_fields

from .convert import ModelConverter
from .exceptions import IncorrectSchemaTypeError
from .load_instance_mixin import LoadInstanceMixin, _ModelType

if TYPE_CHECKING:
    from sqlalchemy.ext.declarative import DeclarativeMeta


# This isn't really a field; it's a placeholder for the metaclass.
# This should be considered private API.
class SQLAlchemyAutoField(Field):
    def __init__(
        self,
        *,
        column_name: str | None = None,
        model: type[DeclarativeMeta] | None = None,
        table: sa.Table | None = None,
        field_kwargs: dict[str, Any],
    ):
        super().__init__()

        if model and table:
            raise ValueError("Cannot pass both `model` and `table` options.")

        self.column_name = column_name
        self.model = model
        self.table = table
        self.field_kwargs = field_kwargs

    def create_field(
        self,
        schema_opts: SQLAlchemySchemaOpts,
        column_name: str,
        converter: ModelConverter,
    ):
        model = self.model or schema_opts.model
        if model:
            return converter.field_for(model, column_name, **self.field_kwargs)
        table = self.table if self.table is not None else schema_opts.table
        column = getattr(cast(sa.Table, table).columns, column_name)
        return converter.column2field(column, **self.field_kwargs)

    # This field should never be bound to a schema.
    # If this method is called, it's probably because the schema is not a SQLAlchemySchema.
    def _bind_to_schema(self, field_name: str, parent: Schema | Field) -> None:
        raise IncorrectSchemaTypeError(
            f"Cannot bind SQLAlchemyAutoField. Make sure that {parent} is a SQLAlchemySchema or SQLAlchemyAutoSchema."
        )


[docs] class SQLAlchemySchemaOpts(LoadInstanceMixin.Opts, SchemaOpts): """Options class for `SQLAlchemySchema`. Adds the following options: - ``model``: The SQLAlchemy model to generate the `Schema` from (mutually exclusive with ``table``). - ``table``: The SQLAlchemy table to generate the `Schema` from (mutually exclusive with ``model``). - ``load_instance``: Whether to load model instances. - ``sqla_session``: SQLAlchemy session to be used for deserialization. This is only needed when ``load_instance`` is `True`. You can also pass a session to the Schema's `load` method. - ``transient``: Whether to load model instances in a transient state (effectively ignoring the session). Only relevant when ``load_instance`` is `True`. - ``model_converter``: `ModelConverter` class to use for converting the SQLAlchemy model to marshmallow fields. """ table: sa.Table | None model_converter: type[ModelConverter] def __init__(self, meta, *args, **kwargs): super().__init__(meta, *args, **kwargs) self.table = getattr(meta, "table", None) if self.model is not None and self.table is not None: raise ValueError("Cannot set both `model` and `table` options.") self.model_converter = getattr(meta, "model_converter", ModelConverter)
[docs] class SQLAlchemyAutoSchemaOpts(SQLAlchemySchemaOpts): """Options class for `SQLAlchemyAutoSchema`. Has the same options as `SQLAlchemySchemaOpts`, with the addition of: - ``include_fk``: Whether to include foreign fields; defaults to `False`. - ``include_relationships``: Whether to include relationships; defaults to `False`. """ include_fk: bool include_relationships: bool def __init__(self, meta, *args, **kwargs): super().__init__(meta, *args, **kwargs) self.include_fk = getattr(meta, "include_fk", False) self.include_relationships = getattr(meta, "include_relationships", False) if self.table is not None and self.include_relationships: raise ValueError("Cannot set `table` and `include_relationships = True`.")
class SQLAlchemySchemaMeta(SchemaMeta): @classmethod def get_declared_fields( mcs, klass, cls_fields: list[tuple[str, Field]], inherited_fields: list[tuple[str, Field]], dict_cls: type[dict] = dict, ) -> dict[str, Field]: opts = klass.opts Converter: type[ModelConverter] = opts.model_converter converter = Converter(schema_cls=klass) fields = super().get_declared_fields( klass, cls_fields, # Filter out fields generated from foreign key columns # if include_fk is set to False in the options mcs._maybe_filter_foreign_keys(inherited_fields, opts=opts, klass=klass), dict_cls, ) fields.update(mcs.get_declared_sqla_fields(fields, converter, opts, dict_cls)) fields.update(mcs.get_auto_fields(fields, converter, opts, dict_cls)) return fields @classmethod def get_declared_sqla_fields( mcs, base_fields: dict[str, Field], converter: ModelConverter, opts: Any, dict_cls: type[dict], ) -> dict[str, Field]: return {} @classmethod def get_auto_fields( mcs, fields: dict[str, Field], converter: ModelConverter, opts: Any, dict_cls: type[dict], ) -> dict[str, Field]: return dict_cls( { field_name: field.create_field( opts, field.column_name or field_name, converter ) for field_name, field in fields.items() if isinstance(field, SQLAlchemyAutoField) and field_name not in opts.exclude } ) @staticmethod def _maybe_filter_foreign_keys( fields: list[tuple[str, Field]], *, opts: SQLAlchemySchemaOpts, klass: SchemaMeta, ) -> list[tuple[str, Field]]: if opts.model is not None or opts.table is not None: if not hasattr(opts, "include_fk") or opts.include_fk is True: return fields foreign_keys = { column.key for column in sa.inspect(opts.model or opts.table).columns # type: ignore[union-attr] if column.foreign_keys } non_auto_schema_bases = [ base for base in inspect.getmro(klass) if issubclass(base, Schema) and not issubclass(base, SQLAlchemyAutoSchema) ] def is_declared_field(field: str) -> bool: return any( field in [ name for name, _ in _get_fields( getattr(base, "_declared_fields", base.__dict__) ) ] for base in non_auto_schema_bases ) return [ (name, field) for name, field in fields if name not in foreign_keys or is_declared_field(name) ] return fields class SQLAlchemyAutoSchemaMeta(SQLAlchemySchemaMeta): @classmethod def get_declared_sqla_fields( cls, base_fields, converter: ModelConverter, opts, dict_cls ): fields = dict_cls() if opts.table is not None: fields.update( converter.fields_for_table( opts.table, fields=opts.fields, exclude=opts.exclude, include_fk=opts.include_fk, base_fields=base_fields, dict_cls=dict_cls, ) ) elif opts.model is not None: fields.update( converter.fields_for_model( opts.model, fields=opts.fields, exclude=opts.exclude, include_fk=opts.include_fk, include_relationships=opts.include_relationships, base_fields=base_fields, dict_cls=dict_cls, ) ) return fields
[docs] class SQLAlchemySchema( LoadInstanceMixin.Schema[_ModelType], Schema, metaclass=SQLAlchemySchemaMeta ): """Schema for a SQLAlchemy model or table. Use together with `auto_field` to generate fields from columns. Example: :: from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field from mymodels import User class UserSchema(SQLAlchemySchema): class Meta: model = User id = auto_field() created_at = auto_field(dump_only=True) name = auto_field() """ OPTIONS_CLASS = SQLAlchemySchemaOpts
[docs] class SQLAlchemyAutoSchema( SQLAlchemySchema[_ModelType], metaclass=SQLAlchemyAutoSchemaMeta ): """Schema that automatically generates fields from the columns of a SQLAlchemy model or table. Example: :: from marshmallow_sqlalchemy import SQLAlchemyAutoSchema, auto_field from mymodels import User class UserSchema(SQLAlchemyAutoSchema): class Meta: model = User # OR # table = User.__table__ created_at = auto_field(dump_only=True) """ OPTIONS_CLASS = SQLAlchemyAutoSchemaOpts
[docs] def auto_field( column_name: str | None = None, *, model: type[DeclarativeMeta] | None = None, table: sa.Table | None = None, # TODO: add type annotations for **kwargs **kwargs, ) -> SQLAlchemyAutoField: """Mark a field to autogenerate from a model or table. :param column_name: Name of the column to generate the field from. If ``None``, matches the field name. If ``attribute`` is unspecified, ``attribute`` will be set to the same value as ``column_name``. :param model: Model to generate the field from. If ``None``, uses ``model`` specified on ``class Meta``. :param table: Table to generate the field from. If ``None``, uses ``table`` specified on ``class Meta``. :param kwargs: Field argument overrides. """ if column_name is not None: kwargs.setdefault("attribute", column_name) return SQLAlchemyAutoField( column_name=column_name, model=model, table=table, field_kwargs=kwargs )