Skip to content

Commit 01e241e

Browse files
zhanyuanucbZhanyuan Zhang
authored and
Suqi Sun
committed
WIP: generator accepts NdArray attribute (asyml#575)
* added static methods to infer pack type * added unit tests for datapack type inference. * fixed issue#558 * fixed black: forte/data/selector.py * fixed import-outside-toplevel * WIP: waiting for ndarray supported * Ndarray -> ndarray * values -> value * added ndarray to SUPPORTED_PRIMITIVES * WIP: generator accepts NdArray attribute * remove breakpoints * fixed None -> "None" * WIP: designing NdArrayProperty * WIP: ndarray_size -> ndarray_shape * fixed lint * fixed lint * added ndarray property test cases * removed irrelevant onto spec * fixed lint * fixed lint * fixed black * fixed description * added FNdArray, a wrapper class for NdArray metric * handle None dtype * removed Optional typing * fixed description * added unit tests for ndarray attribute * fixed black * added doc string and more tests * fixed type * added reference to np.ndarray * added unit tests for ndarray attribute against dtype, shape, and warning Co-authored-by: Zhanyuan Zhang <zhanyuan.zhang@petuum.com>
1 parent 395612b commit 01e241e

15 files changed

+619
-6
lines changed

forte/data/data_pack.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -1536,9 +1536,7 @@ def in_span(self, inner_entry: Union[int, Entry], span: Span) -> bool:
15361536
inner_begin = -1
15371537
inner_end = -1
15381538

1539-
if isinstance(inner_entry, Annotation) or isinstance(
1540-
inner_entry, AudioAnnotation
1541-
):
1539+
if isinstance(inner_entry, (Annotation, AudioAnnotation)):
15421540
inner_begin = inner_entry.begin
15431541
inner_end = inner_entry.end
15441542
elif isinstance(inner_entry, Link):

forte/data/ontology/code_generation_objects.py

+46
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from abc import ABC
1818
from pathlib import Path
1919
from typing import Optional, Any, List, Dict, Set, Tuple
20+
from numpy import ndarray
2021

2122
from forte.data.ontology.code_generation_exceptions import (
2223
CodeGenerationException,
@@ -382,6 +383,51 @@ def to_field_value(self):
382383
return self.name
383384

384385

386+
class NdArrayProperty(Property):
    """
    Property describing an ``NdArray`` attribute parsed from an ontology
    specification.

    The property is always typed as ``forte.data.ontology.core.FNdArray``
    and renders an ``FNdArray(...)`` constructor call as the default value
    of the attribute in the generated code.

    Args:
        import_manager: Import manager used to resolve the name under which
            the ``FNdArray`` type is referenced.
        name: Name of the attribute.
        ndarray_dtype: Optional data type name of the array.
        ndarray_shape: Optional shape of the array.
        description: Optional description of the attribute.
        default_val: Optional default array forwarded to the base class.
    """

    def __init__(
        self,
        import_manager: ImportManager,
        name: str,
        ndarray_dtype: Optional[str] = None,
        ndarray_shape: Optional[List[int]] = None,
        description: Optional[str] = None,
        default_val: Optional[ndarray] = None,
    ):
        # The type string is assigned before delegating to the base class,
        # which receives it as the property type.
        self.type_str = "forte.data.ontology.core.FNdArray"
        super().__init__(
            import_manager,
            name,
            self.type_str,
            description=description,
            default_val=default_val,
        )
        self.ndarray_dtype: Optional[str] = ndarray_dtype
        self.ndarray_shape: Optional[List[int]] = ndarray_shape

    def internal_type_str(self) -> str:
        """Return the name to use for the ``FNdArray`` type, as a string."""
        return f"{self.import_manager.get_name_to_use(self.type_str)}"

    def default_value(self) -> str:
        """
        Render the ``FNdArray(...)`` constructor call used as default value.

        A missing dtype is emitted as the bare literal ``None``; any other
        dtype is wrapped in single quotes so it appears as a string literal.
        """
        if self.ndarray_dtype is None:
            dtype_literal = "None"
        else:
            dtype_literal = f"'{self.ndarray_dtype}'"
        return f"FNdArray(shape={self.ndarray_shape}, dtype={dtype_literal})"

    def _full_class(self):
        """Return the name to use for the ``FNdArray`` type."""
        return self.import_manager.get_name_to_use(self.type_str)

    def to_field_value(self):
        """Return the attribute name."""
        return self.name
429+
430+
385431
class DictProperty(Property):
386432
def __init__(
387433
self,

forte/data/ontology/core.py

+67
Original file line numberDiff line numberDiff line change
@@ -598,6 +598,73 @@ def __iter__(self) -> Iterator[KeyType]:
598598
yield from self.__data
599599

600600

601+
class FNdArray:
    """
    FNdArray is a wrapper of a NumPy array that stores shape and data type
    of the array if they are specified. Only when both a data type and a
    non-empty shape are provided, will FNdArray initialize a placeholder
    array through np.ndarray(shape, dtype=dtype).
    More details about np.ndarray(...):
    https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html

    Args:
        dtype: Optional data type name of the array; converted with
            ``np.dtype`` when given.
        shape: Optional iterable of dimension sizes; stored as a tuple when
            given. One-shot iterables (iterators, generators) are accepted.
    """

    def __init__(
        self, dtype: Optional[str] = None, shape: Optional[Iterable[int]] = None
    ):
        super().__init__()
        self._dtype: Optional[np.dtype] = (
            np.dtype(dtype) if dtype is not None else None
        )
        self._shape: Optional[tuple] = (
            tuple(shape) if shape is not None else None
        )
        self._data: Optional[np.ndarray] = None
        # Build the placeholder from the materialised tuple and dtype, never
        # from the raw `shape` argument: that argument may be an iterator
        # which `tuple(shape)` above has already exhausted.
        if self._dtype is not None and self._shape:
            self._data = np.ndarray(self._shape, dtype=self._dtype)

    @property
    def dtype(self):
        """The stored ``np.dtype``, or ``None`` when not specified."""
        return self._dtype

    @property
    def shape(self):
        """The stored shape tuple, or ``None`` when not specified."""
        return self._shape

    @property
    def data(self):
        """The wrapped array, or ``None`` when nothing is stored yet."""
        return self._data

    @data.setter
    def data(self, array: Union[np.ndarray, List]):
        """
        Store ``array`` after validating it against the stored dtype/shape.

        Args:
            array: A NumPy array, or a Python list which is converted with
                ``np.array(array, dtype=self.dtype)``.

        Raises:
            TypeError: If a NumPy array's dtype is not the stored dtype or
                a subtype of it.
            AttributeError: If the (converted) array's shape differs from
                the stored non-empty shape.
            ValueError: If ``array`` is neither a NumPy array nor a list.
        """
        if isinstance(array, np.ndarray):
            if self._dtype is not None and not np.issubdtype(
                array.dtype, self._dtype
            ):
                raise TypeError(
                    f"Expecting type or subtype of {self.dtype}, but got {array.dtype}."
                )
            candidate = array
        elif isinstance(array, list):
            candidate = np.array(array, dtype=self._dtype)
        else:
            raise ValueError(
                f"Can only accept numpy array or python list, but got {type(array)}"
            )

        # An empty stored shape is treated as "no shape constraint".
        if self._shape and self._shape != candidate.shape:
            raise AttributeError(
                f"Expecting shape {self.shape}, but got {candidate.shape}."
            )

        self._data = candidate
        # Stored dtype and shape should match the provided array's.
        self._dtype = candidate.dtype
        self._shape = candidate.shape
666+
667+
601668
class Pointer(BasePointer):
602669
"""
603670
A pointer that points to an entry in the current pack, this is basically

forte/data/ontology/ontology_code_const.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ class SchemaKeywords:
2828
element_type = "item_type"
2929
dict_key_type = "key_type"
3030
dict_value_type = "value_type"
31+
ndarray_dtype = "ndarray_dtype"
32+
ndarray_shape = "ndarray_shape"
3133

3234

3335
# Some names are used as properties by the core types, they should not be
@@ -83,7 +85,7 @@ def get_ignore_error_lines(json_filepath: str) -> List[str]:
8385

8486
SUPPORTED_PRIMITIVES = {"int", "float", "str", "bool"}
8587
NON_COMPOSITES = {key: key for key in SUPPORTED_PRIMITIVES}
86-
COMPOSITES = {"List", "Dict"}
88+
COMPOSITES = {"List", "Dict", "NdArray"}
8789

8890
ALL_INBUILT_TYPES = set(list(NON_COMPOSITES.keys()) + list(COMPOSITES))
8991

forte/data/ontology/ontology_code_generator.py

+38
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import jsonschema
3333
import typed_ast.ast3 as ast
3434
import typed_astunparse as ast_unparse
35+
from numpy import ndarray
3536

3637
from forte.data.ontology import top, utils
3738
from forte.data.ontology.code_generation_exceptions import (
@@ -48,6 +49,7 @@
4849
OntologySourceNotFoundException,
4950
)
5051
from forte.data.ontology.code_generation_objects import (
52+
NdArrayProperty,
5153
NonCompositeProperty,
5254
ListProperty,
5355
ClassTypeDefinition,
@@ -1078,6 +1080,40 @@ def parse_entry(
10781080

10791081
return entry_item, property_names
10801082

1083+
def parse_ndarray(
    self,
    manager: ImportManager,
    schema: Dict,
    att_name: str,
    desc: str,
):
    """
    Build an ``NdArrayProperty`` from the schema of an ``NdArray`` attribute.

    The dtype and shape are read from the schema under the
    ``SchemaKeywords.ndarray_dtype`` and ``SchemaKeywords.ndarray_shape``
    keys; either may be absent, in which case a warning is issued.

    Args:
        manager: Import manager passed on to the created property.
        schema: Schema dictionary of the attribute.
        att_name: Name of the attribute.
        desc: Description of the attribute.

    Returns:
        The ``NdArrayProperty`` describing the attribute.
    """
    ndarray_dtype = schema.get(SchemaKeywords.ndarray_dtype)
    ndarray_shape = schema.get(SchemaKeywords.ndarray_shape)

    incomplete = ndarray_dtype is None or ndarray_shape is None
    if incomplete:
        warnings.warn(
            "Either dtype or shape is not specified."
            " It is recommended to specify both of them."
        )

    # A placeholder array is only created when both values are truthy.
    default_val = (
        ndarray(ndarray_shape, dtype=ndarray_dtype)
        if ndarray_dtype and ndarray_shape
        else None
    )

    return NdArrayProperty(
        manager,
        att_name,
        ndarray_dtype,
        ndarray_shape,
        description=desc,
        default_val=default_val,
    )
1116+
10811117
def parse_dict(
10821118
self,
10831119
manager: ImportManager,
@@ -1250,6 +1286,8 @@ def parse_property(self, entry_name: EntryName, schema: Dict) -> Property:
12501286
return self.parse_dict(
12511287
manager, schema, entry_name, att_name, att_type, desc
12521288
)
1289+
elif att_type == "NdArray":
1290+
return self.parse_ndarray(manager, schema, att_name, desc)
12531291
elif att_type in NON_COMPOSITES or manager.is_imported(att_type):
12541292
self_ref = entry_name.class_name == att_type
12551293
return self.parse_non_composite(

forte/data/ontology/validation_schema.json

+45
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,34 @@
9595
"value_type": {
9696
"description": "Item type for the case of Dict attributes",
9797
"type": "string"
98+
},
99+
"ndarray_dtype": {
100+
"description": "Data type for the case of NdArray attributes. Allow a subset of NumPy supported data types",
101+
"type": "string",
102+
"enum": [
103+
"bool",
104+
"bool8",
105+
"int",
106+
"int8",
107+
"int32",
108+
"int64",
109+
"uint8",
110+
"uint32",
111+
"uint64",
112+
"float",
113+
"float32",
114+
"float64",
115+
"float96",
116+
"float128",
117+
"complex",
118+
"complex128",
119+
"complex192",
120+
"complex256"
121+
]
122+
},
123+
"ndarray_shape": {
124+
"description": "Shape of N-dimensional array for the case of NdArray attributes",
125+
"type": "array"
98126
}
99127
},
100128
"anyOf": [
@@ -135,6 +163,23 @@
135163
}
136164
]
137165
},
166+
{
167+
"allOf": [
168+
{
169+
"properties": {
170+
"name": {
171+
"enum": [
172+
"NdArray"
173+
]
174+
}
175+
},
176+
"required": [
177+
"name",
178+
"type"
179+
]
180+
}
181+
]
182+
},
138183
{
139184
"allOf": [
140185
{

tests/forte/data/audio_annotation_test.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def _process(self, input_pack: DataPack):
3636
pack=input_pack, begin=0, end=len(input_pack.audio)
3737
)
3838

39+
3940
class AudioUtteranceProcessor(PackProcessor):
4041
"""
4142
A processor to add an AudioUtterance annotation to the specified span of
@@ -91,7 +92,7 @@ def test_audio_annotation(self):
9192

9293
# Verify the annotations of each datapack
9394
for pack in self._pipeline.process_dataset(self._test_audio_path):
94-
95+
9596
# Check Recording
9697
recordings = list(pack.get(Recording))
9798
self.assertEqual(len(recordings), 1)
@@ -100,7 +101,7 @@ def test_audio_annotation(self):
100101
# Check total number of AudioAnnotations which should be 3
101102
# (1 Recording + 2 AudioUtterance).
102103
self.assertEqual(pack.num_audio_annotations, 3)
103-
104+
104105
# Check `DataPack.get(AudioUtterance)` and
105106
# `AudioAnnotation.get(AudioUtterance)`
106107
for object in (pack, recordings[0]):

tests/forte/data/ontology/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)