2020class MultiModalHasher :
2121
2222 @classmethod
23- def serialize_item (cls , obj : object ) -> Union [bytes , memoryview ]:
23+ def serialize_item (cls , obj : object ) -> Iterable [ Union [bytes , memoryview ] ]:
2424 # Simple cases
25- if isinstance (obj , str ):
26- return obj .encode ("utf-8" )
2725 if isinstance (obj , (bytes , memoryview )):
28- return obj
26+ return (obj , )
27+ if isinstance (obj , str ):
28+ return (obj .encode ("utf-8" ), )
2929 if isinstance (obj , (int , float )):
30- return np .array (obj ).tobytes ()
30+ return ( np .array (obj ).tobytes (), )
3131
3232 if isinstance (obj , Image .Image ):
3333 exif = obj .getexif ()
3434 if Image .ExifTags .Base .ImageID in exif and isinstance (
3535 exif [Image .ExifTags .Base .ImageID ], uuid .UUID ):
3636 # If the image has exif ImageID tag, use that
37- return exif [Image .ExifTags .Base .ImageID ].bytes
38- return cls .item_to_bytes (
37+ return ( exif [Image .ExifTags .Base .ImageID ].bytes , )
38+ return cls .iter_item_to_bytes (
3939 "image" , np .asarray (convert_image_mode (obj , "RGBA" )))
4040 if isinstance (obj , torch .Tensor ):
4141 tensor_obj : torch .Tensor = obj .cpu ()
@@ -49,43 +49,34 @@ def serialize_item(cls, obj: object) -> Union[bytes, memoryview]:
4949 tensor_obj = tensor_obj .view (
5050 (tensor_obj .numel (), )).view (torch .uint8 )
5151
52- return cls .item_to_bytes (
52+ return cls .iter_item_to_bytes (
5353 "tensor" , {
5454 "original_dtype" : str (tensor_dtype ),
5555 "original_shape" : tuple (tensor_shape ),
5656 "data" : tensor_obj .numpy (),
5757 })
58-
59- return cls .item_to_bytes ("tensor" , tensor_obj .numpy ())
58+ return cls .iter_item_to_bytes ("tensor" , tensor_obj .numpy ())
6059 if isinstance (obj , np .ndarray ):
6160 # If the array is non-contiguous, we need to copy it first
62- arr_data = obj .data if obj .flags .c_contiguous else obj .tobytes ()
63- return cls .item_to_bytes ("ndarray" , {
61+ arr_data = obj .view (
62+ np .uint8 ).data if obj .flags .c_contiguous else obj .tobytes ()
63+ return cls .iter_item_to_bytes ("ndarray" , {
6464 "dtype" : obj .dtype .str ,
6565 "shape" : obj .shape ,
6666 "data" : arr_data ,
6767 })
68-
6968 logger .warning (
7069 "No serialization method found for %s. "
7170 "Falling back to pickle." , type (obj ))
7271
73- return pickle .dumps (obj )
74-
75- @classmethod
76- def item_to_bytes (
77- cls ,
78- key : str ,
79- obj : object ,
80- ) -> bytes :
81- return b'' .join (kb + vb for kb , vb in cls .iter_item_to_bytes (key , obj ))
72+ return (pickle .dumps (obj ), )
8273
8374 @classmethod
8475 def iter_item_to_bytes (
8576 cls ,
8677 key : str ,
8778 obj : object ,
88- ) -> Iterable [tuple [ bytes , Union [bytes , memoryview ] ]]:
79+ ) -> Iterable [Union [bytes , memoryview ]]:
8980 # Recursive cases
9081 if isinstance (obj , (list , tuple )):
9182 for i , elem in enumerate (obj ):
@@ -94,17 +85,15 @@ def iter_item_to_bytes(
9485 for k , v in obj .items ():
9586 yield from cls .iter_item_to_bytes (f"{ key } .{ k } " , v )
9687 else :
97- key_bytes = key .encode ("utf-8" )
98- value_bytes = cls .serialize_item (obj )
99- yield key_bytes , value_bytes
88+ yield key .encode ("utf-8" )
89+ yield from cls .serialize_item (obj )
10090
@classmethod
def hash_kwargs(cls, **kwargs: object) -> str:
    """Return a hex digest computed over the given keyword arguments.

    Each keyword/value pair is expanded by ``cls.iter_item_to_bytes`` into
    a stream of byte chunks, and every chunk is fed to a single ``blake3``
    hasher. Pairs are consumed in the order the keyword arguments were
    passed (no sorting is applied), so the digest depends on that order.

    Args:
        **kwargs: Arbitrary named objects to include in the hash.

    Returns:
        The hexadecimal digest string produced by the hasher.
    """
    digest = blake3()

    for name, value in kwargs.items():
        # Chunks are hashed incrementally instead of being joined first.
        for chunk in cls.iter_item_to_bytes(name, value):
            digest.update(chunk)

    return digest.hexdigest()
0 commit comments