Source code for wtforms_alchemy.generator
try:
from collections import OrderedDict
except ImportError:
from ordereddict import OrderedDict
try:
from enum import Enum
except ImportError:
Enum = None
import inspect
from decimal import Decimal
import six
import sqlalchemy as sa
from sqlalchemy.orm.properties import ColumnProperty
from sqlalchemy_utils import types
from wtforms import (
BooleanField,
Field,
FloatField,
PasswordField,
TextAreaField
)
from wtforms.widgets import CheckboxInput, TextArea
from wtforms_components import (
ColorField,
DateField,
DateIntervalField,
DateTimeField,
DateTimeIntervalField,
DateTimeLocalField,
DecimalField,
DecimalIntervalField,
EmailField,
IntegerField,
IntIntervalField,
SelectField,
StringField,
TimeField
)
from wtforms_components.widgets import (
ColorInput,
DateInput,
DateTimeInput,
DateTimeLocalInput,
EmailInput,
NumberInput,
TextInput,
TimeInput
)
from .exc import (
AttributeTypeException,
InvalidAttributeException,
UnknownTypeException
)
from .fields import CountryField, PhoneNumberField, WeekDaysField
from .utils import (
choice_type_coerce_factory,
ClassMap,
flatten,
is_date_column,
is_number,
is_number_range,
is_scalar,
null_or_unicode,
strip_string,
translated_attributes
)
[docs]class FormGenerator(object):
"""
Base form generator, you can make your own form generators by inheriting
this class.
"""
# When converting SQLAlchemy types to fields this ordered dict is iterated
# in given order. This allows smart type conversion of different inherited
# type objects.
TYPE_MAP = ClassMap((
(sa.types.UnicodeText, TextAreaField),
(sa.types.BigInteger, IntegerField),
(sa.types.SmallInteger, IntegerField),
(sa.types.Text, TextAreaField),
(sa.types.Boolean, BooleanField),
(sa.types.Date, DateField),
(sa.types.DateTime, DateTimeField),
(sa.types.Enum, SelectField),
(sa.types.Float, FloatField),
(sa.types.Integer, IntegerField),
(sa.types.Numeric, DecimalField),
(sa.types.Unicode, StringField),
(sa.types.String, StringField),
(sa.types.Time, TimeField),
(types.ArrowType, DateTimeField),
(types.ChoiceType, SelectField),
(types.ColorType, ColorField),
(types.CountryType, CountryField),
(types.DateRangeType, DateIntervalField),
(types.DateTimeRangeType, DateTimeIntervalField),
(types.EmailType, EmailField),
(types.IntRangeType, IntIntervalField),
(types.NumericRangeType, DecimalIntervalField),
(types.PasswordType, PasswordField),
(types.PhoneNumberType, PhoneNumberField),
(types.ScalarListType, StringField),
(types.URLType, StringField),
(types.UUIDType, StringField),
(types.WeekDaysType, WeekDaysField),
))
WIDGET_MAP = OrderedDict((
(BooleanField, CheckboxInput),
(ColorField, ColorInput),
(DateField, DateInput),
(DateTimeField, DateTimeInput),
(DateTimeLocalField, DateTimeLocalInput),
(DecimalField, NumberInput),
(EmailField, EmailInput),
(FloatField, NumberInput),
(IntegerField, NumberInput),
(TextAreaField, TextArea),
(TimeField, TimeInput),
(StringField, TextInput)
))
def __init__(self, form_class):
"""
Initializes the form generator
:param form_class: ModelForm class to be used as the base of generation
process
"""
self.form_class = form_class
self.model_class = self.form_class.Meta.model
self.meta = self.form_class.Meta
self.TYPE_MAP.update(self.form_class.Meta.type_map)
[docs] def create_form(self, form):
"""
Creates the form.
:param form: ModelForm instance
"""
attrs = OrderedDict()
for key, property_ in sa.inspect(self.model_class).attrs.items():
if not isinstance(property_, ColumnProperty):
continue
if self.skip_column_property(property_):
continue
attrs[key] = property_
for attr in translated_attributes(self.model_class):
attrs[attr.key] = attr.property
return self.create_fields(form, self.filter_attributes(attrs))
[docs] def filter_attributes(self, attrs):
"""
Filter set of model attributes based on only, exclude and include
meta parameters.
:param attrs: Set of attributes
"""
if self.meta.only:
attrs = OrderedDict([
(key, prop)
for key, prop in map(self.validate_attribute, self.meta.only)
if key
])
else:
if self.meta.include:
attrs.update([
(key, prop)
for key, prop
in map(self.validate_attribute, self.meta.include)
if key
])
if self.meta.exclude:
for key in self.meta.exclude:
try:
del attrs[key]
except KeyError:
if self.meta.attr_errors:
raise InvalidAttributeException(key)
return attrs
[docs] def validate_attribute(self, attr_name):
"""
Finds out whether or not given sqlalchemy model attribute name is
valid. Returns attribute property if valid.
:param attr_name: Attribute name
"""
try:
attr = getattr(self.model_class, attr_name)
except AttributeError:
try:
translation_class = (
self.model_class.__translatable__['class']
)
attr = getattr(translation_class, attr_name)
except AttributeError:
if self.meta.attr_errors:
raise InvalidAttributeException(attr_name)
else:
return None, None
try:
if not isinstance(attr.property, ColumnProperty):
if self.meta.attr_errors:
raise InvalidAttributeException(attr_name)
else:
return None, None
except AttributeError:
raise AttributeTypeException(attr_name)
return attr_name, attr.property
[docs] def create_fields(self, form, properties):
"""
Creates fields for given form based on given model attributes.
:param form: form to attach the generated fields into
:param attributes: model attributes to generate the form fields from
"""
for key, prop in properties.items():
column = prop.columns[0]
try:
field = self.create_field(prop, column)
except UnknownTypeException:
if not self.meta.skip_unknown_types:
raise
else:
continue
if not hasattr(form, key):
setattr(form, key, field)
[docs] def skip_column_property(self, column_property):
"""
Whether or not to skip column property in the generation process.
:param column_property: SQLAlchemy ColumnProperty object
"""
if column_property._is_polymorphic_discriminator:
return True
return self.skip_column(column_property.columns[0])
[docs] def skip_column(self, column):
"""
Whether or not to skip column in the generation process.
:param column_property: SQLAlchemy Column object
"""
if not self.meta.include_foreign_keys and column.foreign_keys:
return True
if not self.meta.include_primary_keys and column.primary_key:
return True
if (not self.meta.include_datetimes_with_default and
isinstance(column.type, sa.types.DateTime) and
column.default):
return True
if isinstance(column.type, types.TSVectorType):
return True
if self.meta.only_indexed_fields and not self.has_index(column):
return True
# Skip all non columns (this is the case when using column_property
# methods).
if not isinstance(column, sa.Column):
return True
return False
[docs] def has_index(self, column):
"""
Whether or not given column has an index.
:param column: Column object to inspect the indexes from
"""
if column.primary_key or column.foreign_keys:
return True
table = column.table
for index in table.indexes:
if len(index.columns) == 1 and column.name in index.columns:
return True
return False
[docs] def create_field(self, prop, column):
"""
Create form field for given column.
:param prop: SQLAlchemy ColumnProperty object.
:param column: SQLAlchemy Column object.
"""
kwargs = {}
field_class = self.get_field_class(column)
kwargs['default'] = self.default(column)
kwargs['validators'] = self.create_validators(prop, column)
kwargs['filters'] = self.filters(column)
kwargs.update(self.type_agnostic_parameters(prop.key, column))
kwargs.update(self.type_specific_parameters(column))
if prop.key in self.meta.field_args:
kwargs.update(self.meta.field_args[prop.key])
if issubclass(field_class, DecimalField):
if hasattr(column.type, 'scale'):
kwargs['places'] = column.type.scale
field = field_class(**kwargs)
return field
[docs] def default(self, column):
"""
Return field default for given column.
:param column: SQLAlchemy Column object
"""
if column.default and is_scalar(column.default.arg):
return column.default.arg
else:
if not column.nullable:
return self.meta.default
[docs] def filters(self, column):
"""
Return filters for given column.
:param column: SQLAlchemy Column object
"""
should_trim = column.info.get('trim', None)
filters = column.info.get('filters', [])
if (
(
isinstance(column.type, sa.types.String) and
self.meta.strip_string_fields and
should_trim is None
) or
should_trim is True
):
filters.append(strip_string)
return filters
[docs] def date_format(self, column):
"""
Returns date format for given column.
:param column: SQLAlchemy Column object
"""
if (
isinstance(column.type, sa.types.DateTime) or
isinstance(column.type, types.ArrowType)
):
return self.meta.datetime_format
if isinstance(column.type, sa.types.Date):
return self.meta.date_format
[docs] def type_specific_parameters(self, column):
"""
Returns type specific parameters for given column.
:param column: SQLAlchemy Column object
"""
kwargs = {}
if (
hasattr(column.type, 'enums') or
column.info.get('choices') or
isinstance(column.type, types.ChoiceType)
):
kwargs.update(self.select_field_kwargs(column))
date_format = self.date_format(column)
if date_format:
kwargs['format'] = date_format
if hasattr(column.type, 'region'):
kwargs['region'] = column.type.region
kwargs['widget'] = self.widget(column)
return kwargs
[docs] def widget(self, column):
"""
Returns WTForms widget for given column.
:param column: SQLAlchemy Column object
"""
widget = column.info.get('widget', None)
if widget is not None:
return widget
kwargs = {}
step = column.info.get('step', None)
if step is not None:
kwargs['step'] = step
else:
if isinstance(column.type, sa.types.Numeric):
if (
column.type.scale is not None and
not column.info.get('choices')
):
kwargs['step'] = self.scale_to_step(column.type.scale)
if kwargs:
widget_class = self.WIDGET_MAP[
self.get_field_class(column)
]
return widget_class(**kwargs)
[docs] def scale_to_step(self, scale):
"""
Returns HTML5 compatible step attribute for given decimal scale.
:param scale: an integer that defines a Numeric column's scale
"""
return str(pow(Decimal('0.1'), scale))
[docs] def type_agnostic_parameters(self, key, column):
"""
Returns all type agnostic form field parameters for given column.
:param column: SQLAlchemy Column object
"""
kwargs = {}
kwargs['description'] = column.info.get('description', '')
kwargs['label'] = column.info.get('label', key)
return kwargs
[docs] def select_field_kwargs(self, column):
"""
Returns key value args for SelectField based on SQLAlchemy column
definitions.
:param column: SQLAlchemy Column object
"""
kwargs = {}
kwargs['coerce'] = self.coerce(column)
if isinstance(column.type, types.ChoiceType):
choices = column.type.choices
if (
Enum is not None and
isinstance(choices, type)
and issubclass(choices, Enum)
):
kwargs['choices'] = [
(choice.value, str(choice)) for choice in choices
]
else:
kwargs['choices'] = choices
elif 'choices' in column.info and column.info['choices']:
kwargs['choices'] = column.info['choices']
else:
kwargs['choices'] = [
(enum, enum) for enum in column.type.enums
]
return kwargs
[docs] def coerce(self, column):
"""
Returns coerce callable for given column
:param column: SQLAlchemy Column object
"""
if 'coerce' in column.info:
return column.info['coerce']
if isinstance(column.type, types.ChoiceType):
return choice_type_coerce_factory(column.type)
try:
python_type = column.type.python_type
except NotImplementedError:
return null_or_unicode
if column.nullable and issubclass(python_type, six.string_types):
return null_or_unicode
return python_type
[docs] def create_validators(self, prop, column):
"""
Returns validators for given column
:param column: SQLAlchemy Column object
"""
validators = [
self.required_validator(column),
self.length_validator(column),
self.unique_validator(prop.key, column),
self.range_validator(column)
]
if isinstance(column.type, types.EmailType):
validators.append(self.get_validator('email'))
if isinstance(column.type, types.URLType):
validators.append(self.get_validator('url'))
validators = flatten([v for v in validators if v is not None])
validators.extend(self.additional_validators(prop.key, column))
return validators
[docs] def required_validator(self, column):
"""
Returns required / optional validator for given column based on column
nullability and form configuration.
:param column: SQLAlchemy Column object
"""
if (not self.meta.all_fields_optional and
not column.default and
not column.nullable):
type_map = self.meta.not_null_validator_type_map
try:
return type_map[column.type]
except KeyError:
if isinstance(column.type, sa.types.TypeDecorator):
type_ = column.type.impl
try:
return type_map[type_]
except KeyError:
pass
if self.meta.not_null_validator is not None:
return self.meta.not_null_validator
return self.get_validator('optional')
def get_validator(self, name, **kwargs):
attr_name = '%s_validator' % name
attr = getattr(self.meta, attr_name)
if attr is None:
return attr
if inspect.ismethod(attr):
return six.get_unbound_function(attr)(**kwargs)
else:
return attr(**kwargs)
[docs] def additional_validators(self, key, column):
"""
Returns additional validators for given column
:param key: String key of the column property
:param column: SQLAlchemy Column object
"""
validators = []
if key in self.meta.validators:
try:
validators.extend(self.meta.validators[key])
except TypeError:
validators.append(self.meta.validators[key])
if 'validators' in column.info and column.info['validators']:
try:
validators.extend(column.info['validators'])
except TypeError:
validators.append(column.info['validators'])
return validators
[docs] def unique_validator(self, key, column):
"""
Returns unique validator for given column if column has a unique index
:param key: String key of the column property
:param column: SQLAlchemy Column object
"""
if column.unique:
return self.get_validator(
'unique',
column=getattr(self.model_class, key),
get_session=self.form_class.get_session
)
[docs] def range_validator(self, column):
"""
Returns range validator based on column type and column info min and
max arguments
:param column: SQLAlchemy Column object
"""
min_ = column.info.get('min')
max_ = column.info.get('max')
if min_ is not None or max_ is not None:
if is_number(column.type) or is_number_range(column.type):
return self.get_validator('number_range', min=min_, max=max_)
elif is_date_column(column):
return self.get_validator('date_range', min=min_, max=max_)
elif isinstance(column.type, sa.types.Time):
return self.get_validator('time_range', min=min_, max=max_)
[docs] def length_validator(self, column):
"""
Returns length validator for given column
:param column: SQLAlchemy Column object
"""
if (
isinstance(column.type, sa.types.String) and
hasattr(column.type, 'length') and
column.type.length
):
return self.get_validator('length', max=column.type.length)
[docs] def get_field_class(self, column):
"""
Returns WTForms field class. Class is based on a custom field class
attribute or SQLAlchemy column type.
:param column: SQLAlchemy Column object
"""
if (
'form_field_class' in column.info and
column.info['form_field_class']
):
return column.info['form_field_class']
if 'choices' in column.info and column.info['choices']:
return SelectField
if (
column.type not in self.TYPE_MAP and
isinstance(column.type, sa.types.TypeDecorator)
):
check_type = column.type.impl
else:
check_type = column.type
try:
column_type = self.TYPE_MAP[check_type]
if inspect.isclass(column_type) and issubclass(column_type, Field):
return column_type
else:
return column_type(column)
except KeyError:
raise UnknownTypeException(column)