99import static org .opensearch .ml .common .CommonValue .VERSION_3_1_0 ;
1010
1111import java .io .IOException ;
12+ import java .util .Locale ;
1213import java .util .Map ;
1314import java .util .Set ;
1415import java .util .stream .Collectors ;
@@ -41,20 +42,60 @@ public class BaseModelConfig extends MLModelConfig {
4142 it -> parse (it )
4243 );
4344
// Keys under which parse() reads, and toXContent() writes, the embedding-related settings.
45+ public static final String EMBEDDING_DIMENSION_FIELD = "embedding_dimension" ;
46+ public static final String FRAMEWORK_TYPE_FIELD = "framework_type" ;
47+ public static final String POOLING_MODE_FIELD = "pooling_mode" ;
48+ public static final String NORMALIZE_RESULT_FIELD = "normalize_result" ;
49+ public static final String MODEL_MAX_LENGTH_FIELD = "model_max_length" ;
// NOTE(review): the next two key constants omit the _FIELD suffix used by their siblings.
50+ public static final String QUERY_PREFIX = "query_prefix" ;
51+ public static final String PASSAGE_PREFIX = "passage_prefix" ;
// Key for the free-form map held in additionalConfig.
4452 public static final String ADDITIONAL_CONFIG_FIELD = "additional_config" ;
53+

// Values stored for the keys above. The boxed, enum and String fields may be null;
// toXContent() leaves a null field out of the output.
54+ protected Integer embeddingDimension ;
55+ protected FrameworkType frameworkType ;
56+ protected PoolingMode poolingMode ;
// Primitive flag: parse() starts it at false and toXContent() writes it only when it is true.
57+ protected boolean normalizeResult ;
58+ protected Integer modelMaxLength ;
59+ protected String queryPrefix ;
60+ protected String passagePrefix ;
// Free-form extra settings; toXContent() writes the map only when it is non-null.
4561 protected Map <String , Object > additionalConfig ;
4662
/**
 * Builds a config from already-parsed values; this is the constructor that
 * {@link #parse(XContentParser)} calls.
 *
 * <p>{@code modelType} and {@code allConfig} are handed to the superclass constructor. Every
 * other argument is stored unchanged in the field of the same name, so a {@code null}
 * argument leaves that field {@code null}. After all fields are assigned,
 * {@code validateNoDuplicateKeys(allConfig, additionalConfig)} is run as the last step.
 *
 * <p>The builder entry point generated for this constructor is named
 * {@code baseModelConfigBuilder} (presumably via Lombok's {@code @Builder}; the import is
 * outside this view).
 *
 * @param modelType          forwarded to the superclass
 * @param allConfig          forwarded to the superclass and checked against {@code additionalConfig}
 * @param additionalConfig   free-form extra settings, may be {@code null}
 * @param embeddingDimension value for the {@code embedding_dimension} key, may be {@code null}
 * @param frameworkType      value for the {@code framework_type} key, may be {@code null}
 * @param poolingMode        value for the {@code pooling_mode} key, may be {@code null}
 * @param normalizeResult    value for the {@code normalize_result} key
 * @param modelMaxLength     value for the {@code model_max_length} key, may be {@code null}
 * @param queryPrefix        value for the {@code query_prefix} key, may be {@code null}
 * @param passagePrefix      value for the {@code passage_prefix} key, may be {@code null}
 */
@Builder(builderMethodName = "baseModelConfigBuilder")
public BaseModelConfig(
    String modelType,
    String allConfig,
    Map<String, Object> additionalConfig,
    Integer embeddingDimension,
    FrameworkType frameworkType,
    PoolingMode poolingMode,
    boolean normalizeResult,
    Integer modelMaxLength,
    String queryPrefix,
    String passagePrefix
) {
    super(modelType, allConfig);

    // Numeric settings.
    this.embeddingDimension = embeddingDimension;
    this.modelMaxLength = modelMaxLength;

    // Enum-typed and boolean settings.
    this.frameworkType = frameworkType;
    this.poolingMode = poolingMode;
    this.normalizeResult = normalizeResult;

    // Text prefixes and the free-form map.
    this.queryPrefix = queryPrefix;
    this.passagePrefix = passagePrefix;
    this.additionalConfig = additionalConfig;

    // Kept last so the check runs on a fully populated instance, as before.
    validateNoDuplicateKeys(allConfig, additionalConfig);
}
5387
5488 public static BaseModelConfig parse (XContentParser parser ) throws IOException {
5589 String modelType = null ;
5690 String allConfig = null ;
5791 Map <String , Object > additionalConfig = null ;
92+ Integer embeddingDimension = null ;
93+ FrameworkType frameworkType = null ;
94+ PoolingMode poolingMode = null ;
95+ boolean normalizeResult = false ;
96+ Integer modelMaxLength = null ;
97+ String queryPrefix = null ;
98+ String passagePrefix = null ;
5899
59100 ensureExpectedToken (XContentParser .Token .START_OBJECT , parser .currentToken (), parser );
60101 while (parser .nextToken () != XContentParser .Token .END_OBJECT ) {
@@ -71,12 +112,44 @@ public static BaseModelConfig parse(XContentParser parser) throws IOException {
71112 case ADDITIONAL_CONFIG_FIELD :
72113 additionalConfig = parser .map ();
73114 break ;
115+ case EMBEDDING_DIMENSION_FIELD :
116+ embeddingDimension = parser .intValue ();
117+ break ;
118+ case FRAMEWORK_TYPE_FIELD :
119+ frameworkType = FrameworkType .from (parser .text ().toUpperCase (Locale .ROOT ));
120+ break ;
121+ case POOLING_MODE_FIELD :
122+ poolingMode = PoolingMode .from (parser .text ().toUpperCase (Locale .ROOT ));
123+ break ;
124+ case NORMALIZE_RESULT_FIELD :
125+ normalizeResult = parser .booleanValue ();
126+ break ;
127+ case MODEL_MAX_LENGTH_FIELD :
128+ modelMaxLength = parser .intValue ();
129+ break ;
130+ case QUERY_PREFIX :
131+ queryPrefix = parser .text ();
132+ break ;
133+ case PASSAGE_PREFIX :
134+ passagePrefix = parser .text ();
135+ break ;
74136 default :
75137 parser .skipChildren ();
76138 break ;
77139 }
78140 }
79- return new BaseModelConfig (modelType , allConfig , additionalConfig );
141+ return new BaseModelConfig (
142+ modelType ,
143+ allConfig ,
144+ additionalConfig ,
145+ embeddingDimension ,
146+ frameworkType ,
147+ poolingMode ,
148+ normalizeResult ,
149+ modelMaxLength ,
150+ queryPrefix ,
151+ passagePrefix
152+ );
80153 }
81154
82155 @ Override
@@ -89,6 +162,21 @@ public BaseModelConfig(StreamInput in) throws IOException {
89162 if (in .getVersion ().onOrAfter (VERSION_3_1_0 )) {
90163 this .additionalConfig = in .readMap ();
91164 }
165+ embeddingDimension = in .readOptionalInt ();
166+ if (in .readBoolean ()) {
167+ frameworkType = in .readEnum (FrameworkType .class );
168+ } else {
169+ frameworkType = null ;
170+ }
171+ if (in .readBoolean ()) {
172+ poolingMode = in .readEnum (PoolingMode .class );
173+ } else {
174+ poolingMode = null ;
175+ }
176+ normalizeResult = in .readBoolean ();
177+ modelMaxLength = in .readOptionalInt ();
178+ queryPrefix = in .readOptionalString ();
179+ passagePrefix = in .readOptionalString ();
92180 }
93181
94182 @ Override
@@ -97,6 +185,23 @@ public void writeTo(StreamOutput out) throws IOException {
97185 if (out .getVersion ().onOrAfter (VERSION_3_1_0 )) {
98186 out .writeMap (additionalConfig );
99187 }
188+ out .writeOptionalInt (embeddingDimension );
189+ if (frameworkType != null ) {
190+ out .writeBoolean (true );
191+ out .writeEnum (frameworkType );
192+ } else {
193+ out .writeBoolean (false );
194+ }
195+ if (poolingMode != null ) {
196+ out .writeBoolean (true );
197+ out .writeEnum (poolingMode );
198+ } else {
199+ out .writeBoolean (false );
200+ }
201+ out .writeBoolean (normalizeResult );
202+ out .writeOptionalInt (modelMaxLength );
203+ out .writeOptionalString (queryPrefix );
204+ out .writeOptionalString (passagePrefix );
100205 }
101206
102207 @ Override
@@ -111,10 +216,72 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par
111216 if (additionalConfig != null ) {
112217 builder .field (ADDITIONAL_CONFIG_FIELD , additionalConfig );
113218 }
219+ if (embeddingDimension != null ) {
220+ builder .field (EMBEDDING_DIMENSION_FIELD , embeddingDimension );
221+ }
222+ if (frameworkType != null ) {
223+ builder .field (FRAMEWORK_TYPE_FIELD , frameworkType );
224+ }
225+ if (modelMaxLength != null ) {
226+ builder .field (MODEL_MAX_LENGTH_FIELD , modelMaxLength );
227+ }
228+ if (poolingMode != null ) {
229+ builder .field (POOLING_MODE_FIELD , poolingMode );
230+ }
231+ if (normalizeResult ) {
232+ builder .field (NORMALIZE_RESULT_FIELD , normalizeResult );
233+ }
234+ if (queryPrefix != null ) {
235+ builder .field (QUERY_PREFIX , queryPrefix );
236+ }
237+ if (passagePrefix != null ) {
238+ builder .field (PASSAGE_PREFIX , passagePrefix );
239+ }
114240 builder .endObject ();
115241 return builder ;
116242 }
117243
/**
 * Values accepted for the {@code pooling_mode} setting.
 *
 * <p>Every constant carries a lowercase display name, available through {@link #getName()}.
 * For some constants that display name differs from the constant identifier by more than
 * case (for example {@code WEIGHTED_MEAN} is {@code "weightedmean"}), so {@link #from(String)}
 * accepts either spelling.
 */
public enum PoolingMode {
    MEAN("mean"),
    MEAN_SQRT_LEN("mean_sqrt_len"),
    MAX("max"),
    WEIGHTED_MEAN("weightedmean"),
    CLS("cls"),
    LAST_TOKEN("lasttoken");

    /** Lowercase display name given in the constant's declaration. */
    private final String name;

    PoolingMode(String name) {
        this.name = name;
    }

    /**
     * @return the lowercase display name of this constant
     */
    public String getName() {
        return name;
    }

    /**
     * Resolves a pooling mode from text, ignoring case.
     *
     * <p>The text may be the constant identifier ({@code "WEIGHTED_MEAN"}, {@code "last_token"})
     * or the display name ({@code "weightedmean"}, {@code "LastToken"}).
     *
     * @param value text to resolve
     * @return the matching constant
     * @throws IllegalArgumentException if {@code value} is {@code null} or matches no constant
     */
    public static PoolingMode from(String value) {
        // Explicit guard instead of catching the NullPointerException that toUpperCase would raise.
        if (value == null) {
            throw new IllegalArgumentException("Wrong pooling method");
        }
        String wanted = value.toUpperCase(Locale.ROOT);
        for (PoolingMode mode : values()) {
            boolean matchesIdentifier = mode.name().equals(wanted);
            boolean matchesDisplayName = mode.name.toUpperCase(Locale.ROOT).equals(wanted);
            if (matchesIdentifier || matchesDisplayName) {
                return mode;
            }
        }
        throw new IllegalArgumentException("Wrong pooling method");
    }
}
270+
/**
 * Values accepted for the {@code framework_type} setting.
 */
public enum FrameworkType {
    HUGGINGFACE_TRANSFORMERS,
    SENTENCE_TRANSFORMERS,
    HUGGINGFACE_TRANSFORMERS_NEURON;

    /**
     * Resolves a framework type from its constant identifier, ignoring case.
     *
     * @param value text to resolve, e.g. {@code "sentence_transformers"}
     * @return the matching constant
     * @throws IllegalArgumentException if {@code value} is {@code null} or matches no constant;
     *         for an unknown identifier the lookup failure is attached as the cause
     */
    public static FrameworkType from(String value) {
        // Explicit guard instead of catching the NullPointerException that toUpperCase would raise.
        if (value == null) {
            throw new IllegalArgumentException("Wrong framework type");
        }
        try {
            return FrameworkType.valueOf(value.toUpperCase(Locale.ROOT));
        } catch (IllegalArgumentException e) {
            // Same message as before; the cause records which identifier was not found.
            throw new IllegalArgumentException("Wrong framework type", e);
        }
    }
}
284+
118285 protected void validateNoDuplicateKeys (String allConfig , Map <String , Object > additionalConfig ) {
119286 if (allConfig == null || additionalConfig == null || additionalConfig .isEmpty ()) {
120287 return ;
0 commit comments