@@ -69,19 +69,20 @@ namespace runtime {
69
69
TVM_DLL std::string GetCustomTypeName (uint8_t type_code);
70
70
71
71
/* !
72
- * \brief Runtime utility for getting custom type code from name
73
- * \param type_name Custom type name
74
- * \return Custom type code
75
- */
76
- TVM_DLL uint8_t GetCustomTypeCode (const std::string& type_name);
77
-
78
- /* !
79
- * \brief Runtime utility for checking whether custom type is registered
80
- * \param type_code Custom type code
81
- * \return Bool representing whether type is registered
82
- */
72
+ * \brief Runtime utility for checking whether custom type is registered
73
+ * \param type_code Custom type code
74
+ * \return Bool representing whether type is registered
75
+ */
83
76
TVM_DLL bool GetCustomTypeRegistered (uint8_t type_code);
84
77
78
+ /* !
79
+ * \brief Runtime utility for parsing string of the form "custom[<typename>]"
80
+ * \param s String to parse
81
+ * \param scan pointer to parsing pointer, which is scanning across s
82
+ * \return type code of custom type parsed
83
+ */
84
+ TVM_DLL uint8_t ParseCustomDatatype (const std::string& s, const char ** scan);
85
+
85
86
// forward declarations
86
87
class TVMArgs ;
87
88
class TVMArgValue ;
@@ -1025,22 +1026,7 @@ inline TVMType String2TVMType(std::string s) {
1025
1026
t.lanes = 1 ;
1026
1027
return t;
1027
1028
} else if (s.substr (0 , 6 ) == " custom" ) {
1028
- // TODO(gus) this should be separated out into its own parsing function and cleaned up, or
1029
- // replaced by a regex.
1030
- scan = s.c_str () + 6 ;
1031
- if (*scan != ' [' )
1032
- LOG (FATAL) << " expected opening brace after 'custom' type in" << s;
1033
- ++scan;
1034
- size_t custom_name_len = 0 ;
1035
- while (scan + custom_name_len <= s.c_str () + s.length () &&
1036
- *(scan + custom_name_len) != ' ]' )
1037
- ++custom_name_len;
1038
- if (*(scan + custom_name_len) != ' ]' )
1039
- LOG (FATAL) << " expected closing brace after 'custom' type in" << s;
1040
- scan += custom_name_len + 1 ;
1041
-
1042
- auto type_name = s.substr (7 , custom_name_len);
1043
- t.code = GetCustomTypeCode (type_name);
1029
+ t.code = ParseCustomDatatype (s, &scan);
1044
1030
} else {
1045
1031
scan = s.c_str ();
1046
1032
LOG (FATAL) << " unknown type " << s;
0 commit comments