Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 76 additions & 6 deletions src/idl_gen_python.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,63 @@ class PythonStubGenerator {
}
}

void GenerateObjectInitializerStub(std::stringstream &stub,
const StructDef *struct_def,
Imports *imports) const {
stub << " def __init__(\n";
stub << " self,\n";

for (const FieldDef *field : struct_def->fields.vec) {
if (field->deprecated) continue;

std::string field_name = namer_.Field(*field);
std::string field_type;
const Type &type = field->value.type;

if (IsScalar(type.base_type)) {
field_type = TypeOf(type, imports);
if (field->IsOptional()) { field_type += " | None"; }
} else {
switch (type.base_type) {
case BASE_TYPE_STRUCT: {
Import import_ =
imports->Import(ModuleFor(type.struct_def),
namer_.ObjectType(*type.struct_def));
field_type = "'" + import_.name + "' | None";
break;
}
case BASE_TYPE_STRING:
field_type = "str | None";
break;
case BASE_TYPE_ARRAY:
case BASE_TYPE_VECTOR: {
imports->Import("typing");
if (type.element == BASE_TYPE_STRUCT) {
Import import_ =
imports->Import(ModuleFor(type.struct_def),
namer_.ObjectType(*type.struct_def));
field_type = "typing.List['" + import_.name + "'] | None";
} else if (type.element == BASE_TYPE_STRING) {
field_type = "typing.List[str] | None";
} else {
field_type = "typing.List[" + TypeOf(type.VectorType(), imports) +
"] | None";
}
break;
}
case BASE_TYPE_UNION:
field_type = UnionObjectType(*type.enum_def, imports);
break;
default:
field_type = "typing.Any";
break;
}
}
stub << " " << field_name << ": " << field_type << " = ...,\n";
}
stub << " ) -> None: ...\n";
}

void GenerateObjectStub(std::stringstream &stub, const StructDef *struct_def,
Imports *imports) const {
std::string name = namer_.ObjectType(*struct_def);
Expand All @@ -299,6 +356,8 @@ class PythonStubGenerator {
stub << " " << GenerateObjectFieldStub(field, imports) << "\n";
}

GenerateObjectInitializerStub(stub, struct_def, imports);

stub << " @classmethod\n";
stub << " def InitFromBuf(cls, buf: bytes, pos: int) -> " << name
<< ": ...\n";
Expand Down Expand Up @@ -1674,6 +1733,7 @@ class PythonGenerator : public BaseGenerator {
field_type = package_reference + "." + field_type;
import_list->insert("import " + package_reference);
}
field_type = "'" + field_type + "'";
break;
case BASE_TYPE_STRING: field_type += "str"; break;
case BASE_TYPE_NONE: field_type += "None"; break;
Expand Down Expand Up @@ -1735,8 +1795,12 @@ class PythonGenerator : public BaseGenerator {

void GenInitialize(const StructDef &struct_def, std::string *code_ptr,
std::set<std::string> *import_list) const {
std::string code;
std::string signature_params;
std::string init_body;
std::set<std::string> import_typing_list;

signature_params += GenIndents(2) + "self,";

for (auto it = struct_def.fields.vec.begin();
it != struct_def.fields.vec.end(); ++it) {
auto &field = **it;
Expand All @@ -1763,6 +1827,7 @@ class PythonGenerator : public BaseGenerator {
// Scalar or sting fields.
field_type = GetBasePythonTypeForScalarAndString(base_type);
if (field.IsScalarOptional()) {
import_typing_list.insert("Optional");
field_type = "Optional[" + field_type + "]";
}
break;
Expand All @@ -1771,18 +1836,23 @@ class PythonGenerator : public BaseGenerator {
const auto default_value = GetDefaultValue(field);
// Writes the init statement.
const auto field_field = namer_.Field(field);
code += GenIndents(2) + "self." + field_field + " = " + default_value +
" # type: " + field_type;

// Build signature with keyword arguments, type hints, and default values.
signature_params += GenIndents(2) + field_field + " = " + default_value + ",";

// Build the body of the __init__ method.
init_body += GenIndents(2) + "self." + field_field + " = " + field_field +
" # type: " + field_type;
}

// Writes __init__ method.
auto &code_base = *code_ptr;
GenReceiverForObjectAPI(struct_def, code_ptr);
code_base += "__init__(self):";
if (code.empty()) {
code_base += "__init__(" + signature_params + GenIndents(1) + "):";
if (init_body.empty()) {
code_base += GenIndents(2) + "pass";
} else {
code_base += code;
code_base += init_body;
}
code_base += "\n";

Expand Down
10 changes: 7 additions & 3 deletions tests/MyGame/Example/Ability.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,13 @@ def CreateAbility(builder, id, distance):
class AbilityT(object):

# AbilityT
def __init__(self):
self.id = 0 # type: int
self.distance = 0 # type: int
def __init__(
self,
id = 0,
distance = 0,
):
self.id = id # type: int
self.distance = distance # type: int

@classmethod
def InitFromBuf(cls, buf, pos):
Expand Down
22 changes: 15 additions & 7 deletions tests/MyGame/Example/ArrayStruct.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,13 +114,21 @@ def CreateArrayStruct(builder, a, b, c, d_a, d_b, d_c, d_d, e, f):
class ArrayStructT(object):

# ArrayStructT
def __init__(self):
self.a = 0.0 # type: float
self.b = None # type: Optional[List[int]]
self.c = 0 # type: int
self.d = None # type: Optional[List[MyGame.Example.NestedStruct.NestedStructT]]
self.e = 0 # type: int
self.f = None # type: Optional[List[int]]
def __init__(
self,
a = 0.0,
b = None,
c = 0,
d = None,
e = 0,
f = None,
):
self.a = a # type: float
self.b = b # type: Optional[List[int]]
self.c = c # type: int
self.d = d # type: Optional[List[MyGame.Example.NestedStruct.NestedStructT]]
self.e = e # type: int
self.f = f # type: Optional[List[int]]

@classmethod
def InitFromBuf(cls, buf, pos):
Expand Down
9 changes: 9 additions & 0 deletions tests/MyGame/Example/ArrayStruct.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,15 @@ class ArrayStructT(object):
d: typing.List[NestedStructT]
e: int
f: typing.List[int]
def __init__(
self,
a: float = ...,
b: typing.List[int] | None = ...,
c: int = ...,
d: typing.List['NestedStructT'] | None = ...,
e: int = ...,
f: typing.List[int] | None = ...,
) -> None: ...
@classmethod
def InitFromBuf(cls, buf: bytes, pos: int) -> ArrayStructT: ...
@classmethod
Expand Down
7 changes: 5 additions & 2 deletions tests/MyGame/Example/ArrayTable.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,11 @@ def End(builder: flatbuffers.Builder) -> int:
class ArrayTableT(object):

# ArrayTableT
def __init__(self):
self.a = None # type: Optional[MyGame.Example.ArrayStruct.ArrayStructT]
def __init__(
self,
a = None,
):
self.a = a # type: Optional[MyGame.Example.ArrayStruct.ArrayStructT]

@classmethod
def InitFromBuf(cls, buf, pos):
Expand Down
4 changes: 4 additions & 0 deletions tests/MyGame/Example/ArrayTable.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ class ArrayTable(object):
def A(self) -> ArrayStruct | None: ...
class ArrayTableT(object):
a: ArrayStructT | None
def __init__(
self,
a: 'ArrayStructT' | None = ...,
) -> None: ...
@classmethod
def InitFromBuf(cls, buf: bytes, pos: int) -> ArrayTableT: ...
@classmethod
Expand Down
Loading
Loading