11import os
22from functools import partial
3- from typing import Callable , Dict , List , Optional , Tuple , Type
3+ from typing import Callable , Dict , List , Optional , Tuple
44
55import toml
6+ from django .db .models .fields .related import RelatedField
67from mypy .nodes import MypyFile , TypeInfo
78from mypy .options import Options
8- from mypy .plugin import ClassDefContext , FunctionContext , Plugin , MethodContext , AttributeContext
9+ from mypy .plugin import AttributeContext , ClassDefContext , FunctionContext , MethodContext , Plugin
910from mypy .types import Type as MypyType
1011
11- from django .db .models .fields .related import RelatedField
1212from mypy_django_plugin_newsemanal .django .context import DjangoContext
13- from mypy_django_plugin_newsemanal .lib import fullnames , metadata
14- from mypy_django_plugin_newsemanal .transformers import fields , settings , querysets , init_create
13+ from mypy_django_plugin_newsemanal .lib import fullnames , helpers
14+ from mypy_django_plugin_newsemanal .transformers import fields , forms , init_create , querysets , settings
1515from mypy_django_plugin_newsemanal .transformers .models import process_model_class
1616
1717
@@ -20,7 +20,7 @@ def transform_model_class(ctx: ClassDefContext,
2020 sym = ctx .api .lookup_fully_qualified_or_none (fullnames .MODEL_CLASS_FULLNAME )
2121
2222 if sym is not None and isinstance (sym .node , TypeInfo ):
23- metadata .get_django_metadata (sym .node )['model_bases' ][ctx .cls .fullname ] = 1
23+ helpers .get_django_metadata (sym .node )['model_bases' ][ctx .cls .fullname ] = 1
2424 else :
2525 if not ctx .api .final_iteration :
2626 ctx .api .defer ()
@@ -29,10 +29,18 @@ def transform_model_class(ctx: ClassDefContext,
2929 process_model_class (ctx , django_context )
3030
3131
32+ def transform_form_class (ctx : ClassDefContext ) -> None :
33+ sym = ctx .api .lookup_fully_qualified_or_none (fullnames .BASEFORM_CLASS_FULLNAME )
34+ if sym is not None and isinstance (sym .node , TypeInfo ):
35+ helpers .get_django_metadata (sym .node )['baseform_bases' ][ctx .cls .fullname ] = 1
36+
37+ forms .make_meta_nested_class_inherit_from_any (ctx )
38+
39+
3240def add_new_manager_base (ctx : ClassDefContext ) -> None :
3341 sym = ctx .api .lookup_fully_qualified_or_none (fullnames .MANAGER_CLASS_FULLNAME )
3442 if sym is not None and isinstance (sym .node , TypeInfo ):
35- metadata .get_django_metadata (sym .node )['manager_bases' ][ctx .cls .fullname ] = 1
43+ helpers .get_django_metadata (sym .node )['manager_bases' ][ctx .cls .fullname ] = 1
3644
3745
3846class NewSemanalDjangoPlugin (Plugin ):
@@ -50,24 +58,34 @@ def __init__(self, options: Options) -> None:
5058 def _get_current_queryset_bases (self ) -> Dict [str , int ]:
5159 model_sym = self .lookup_fully_qualified (fullnames .QUERYSET_CLASS_FULLNAME )
5260 if model_sym is not None and isinstance (model_sym .node , TypeInfo ):
53- return (metadata .get_django_metadata (model_sym .node )
61+ return (helpers .get_django_metadata (model_sym .node )
5462 .setdefault ('queryset_bases' , {fullnames .QUERYSET_CLASS_FULLNAME : 1 }))
5563 else :
5664 return {}
5765
5866 def _get_current_manager_bases (self ) -> Dict [str , int ]:
5967 model_sym = self .lookup_fully_qualified (fullnames .MANAGER_CLASS_FULLNAME )
6068 if model_sym is not None and isinstance (model_sym .node , TypeInfo ):
61- return (metadata .get_django_metadata (model_sym .node )
69+ return (helpers .get_django_metadata (model_sym .node )
6270 .setdefault ('manager_bases' , {fullnames .MANAGER_CLASS_FULLNAME : 1 }))
6371 else :
6472 return {}
6573
6674 def _get_current_model_bases (self ) -> Dict [str , int ]:
6775 model_sym = self .lookup_fully_qualified (fullnames .MODEL_CLASS_FULLNAME )
6876 if model_sym is not None and isinstance (model_sym .node , TypeInfo ):
69- return metadata .get_django_metadata (model_sym .node ).setdefault ('model_bases' ,
70- {fullnames .MODEL_CLASS_FULLNAME : 1 })
77+ return helpers .get_django_metadata (model_sym .node ).setdefault ('model_bases' ,
78+ {fullnames .MODEL_CLASS_FULLNAME : 1 })
79+ else :
80+ return {}
81+
82+ def _get_current_form_bases (self ) -> Dict [str , int ]:
83+ model_sym = self .lookup_fully_qualified (fullnames .BASEFORM_CLASS_FULLNAME )
84+ if model_sym is not None and isinstance (model_sym .node , TypeInfo ):
85+ return (helpers .get_django_metadata (model_sym .node )
86+ .setdefault ('baseform_bases' , {fullnames .BASEFORM_CLASS_FULLNAME : 1 ,
87+ fullnames .FORM_CLASS_FULLNAME : 1 ,
88+ fullnames .MODELFORM_CLASS_FULLNAME : 1 }))
7189 else :
7290 return {}
7391
@@ -85,15 +103,20 @@ def get_additional_deps(self, file: MypyFile) -> List[Tuple[int, str, int]]:
85103 if file .fullname () == 'django.conf' and self .django_context .django_settings_module :
86104 return [self ._new_dependency (self .django_context .django_settings_module )]
87105
106+ # for values / values_list
107+ if file .fullname () == 'django.db.models' :
108+ return [self ._new_dependency ('mypy_extensions' ), self ._new_dependency ('typing' )]
109+
88110 # for `get_user_model()`
89- if file .fullname () == 'django.contrib.auth' :
90- auth_user_model_name = self .django_context .settings .AUTH_USER_MODEL
91- try :
92- auth_user_module = self .django_context .apps_registry .get_model (auth_user_model_name ).__module__
93- except LookupError :
94- # get_user_model() model app is not installed
95- return []
96- return [self ._new_dependency (auth_user_module )]
111+ if self .django_context .settings :
112+ if file .fullname () == 'django.contrib.auth' :
113+ auth_user_model_name = self .django_context .settings .AUTH_USER_MODEL
114+ try :
115+ auth_user_module = self .django_context .apps_registry .get_model (auth_user_model_name ).__module__
116+ except LookupError :
117+ # get_user_model() model app is not installed
118+ return []
119+ return [self ._new_dependency (auth_user_module )]
97120
98121 # ensure that all mentioned to='someapp.SomeModel' are loaded with corresponding related Fields
99122 defined_model_classes = self .django_context .model_modules .get (file .fullname ())
@@ -132,9 +155,29 @@ def get_function_hook(self, fullname: str
132155 return partial (init_create .redefine_and_typecheck_model_init , django_context = self .django_context )
133156
134157 def get_method_hook (self , fullname : str
135- ) -> Optional [Callable [[MethodContext ], Type ]]:
136- manager_classes = self ._get_current_manager_bases ()
158+ ) -> Optional [Callable [[MethodContext ], MypyType ]]:
137159 class_fullname , _ , method_name = fullname .rpartition ('.' )
160+ if method_name == 'get_form_class' :
161+ info = self ._get_typeinfo_or_none (class_fullname )
162+ if info and info .has_base (fullnames .FORM_MIXIN_CLASS_FULLNAME ):
163+ return forms .extract_proper_type_for_get_form_class
164+
165+ if method_name == 'get_form' :
166+ info = self ._get_typeinfo_or_none (class_fullname )
167+ if info and info .has_base (fullnames .FORM_MIXIN_CLASS_FULLNAME ):
168+ return forms .extract_proper_type_for_get_form
169+
170+ if method_name == 'values' :
171+ model_info = self ._get_typeinfo_or_none (class_fullname )
172+ if model_info and model_info .has_base (fullnames .QUERYSET_CLASS_FULLNAME ):
173+ return partial (querysets .extract_proper_type_queryset_values , django_context = self .django_context )
174+
175+ if method_name == 'values_list' :
176+ model_info = self ._get_typeinfo_or_none (class_fullname )
177+ if model_info and model_info .has_base (fullnames .QUERYSET_CLASS_FULLNAME ):
178+ return partial (querysets .extract_proper_type_queryset_values_list , django_context = self .django_context )
179+
180+ manager_classes = self ._get_current_manager_bases ()
138181 if class_fullname in manager_classes and method_name == 'create' :
139182 return partial (init_create .redefine_and_typecheck_model_create , django_context = self .django_context )
140183
@@ -146,19 +189,16 @@ def get_base_class_hook(self, fullname: str
146189 if fullname in self ._get_current_manager_bases ():
147190 return add_new_manager_base
148191
192+ if fullname in self ._get_current_form_bases ():
193+ return transform_form_class
194+
149195 def get_attribute_hook (self , fullname : str
150196 ) -> Optional [Callable [[AttributeContext ], MypyType ]]:
151197 class_name , _ , attr_name = fullname .rpartition ('.' )
152198 if class_name == fullnames .DUMMY_SETTINGS_BASE_CLASS :
153199 return partial (settings .get_type_of_settings_attribute ,
154200 django_context = self .django_context )
155201
156- # def get_type_analyze_hook(self, fullname: str
157- # ) -> Optional[Callable[[AnalyzeTypeContext], MypyType]]:
158- # queryset_bases = self._get_current_queryset_bases()
159- # if fullname in queryset_bases:
160- # return partial(querysets.set_first_generic_param_as_default_for_second, fullname=fullname)
161-
162202
163203def plugin (version ):
164204 return NewSemanalDjangoPlugin
0 commit comments