# Source code for marshmallow_sqlalchemy.fields

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Any, cast

from marshmallow import fields
from marshmallow.utils import is_iterable_but_not_string
from sqlalchemy import inspect
from sqlalchemy.orm.exc import NoResultFound

if TYPE_CHECKING:
    from sqlalchemy.ext.declarative import DeclarativeMeta
    from sqlalchemy.orm import MapperProperty


class RelatedList(fields.List):
    """List field whose value lookup bypasses ``fields.List.get_value``."""

    def get_value(self, obj, attr, accessor=None):
        # ``fields.List.get_value`` hands off to the inner container's
        # ``get_value`` whenever that container has ``attribute`` set.
        # That special case is not wanted here, so resolve ``get_value`` on
        # whatever class follows ``fields.List`` in this instance's MRO.
        past_list = super(fields.List, self)
        return past_list.get_value(obj, attr, accessor=accessor)
class Nested(fields.Nested):
    """Nested field that inherits the session from its parent."""

    def _deserialize(self, *args, **kwargs):
        # Only schemas that expose a ``session`` attribute are updated; they
        # receive the root schema's ``session`` and ``transient`` values
        # before normal nested deserialization runs.
        child_schema = self.schema
        if hasattr(child_schema, "session"):
            top_schema = self.root
            child_schema.session = top_schema.session
            child_schema.transient = top_schema.transient
        return super()._deserialize(*args, **kwargs)
def get_primary_keys(model: type[DeclarativeMeta]) -> list[MapperProperty]:
    """Return the mapper property behind each primary-key column of ``model``.

    The properties come back in the same order as ``mapper.primary_key``.

    :param model: SQLAlchemy model class
    """
    mapper = model.__mapper__  # type: ignore[attr-defined]
    property_for = mapper.get_property_by_column
    return list(map(property_for, mapper.primary_key))


def ensure_list(value: Any) -> list:
    """Return ``value`` as a list.

    If ``is_iterable_but_not_string(value)`` is true, the value is copied
    into a new list. Otherwise it is wrapped as the single element of a list.
    """
    if not is_iterable_but_not_string(value):
        return [value]
    return list(value)