When attempting to loadSavedModel, I encountered 'java.lang.Exception: Could not retrieve the SavedModelBundle + ()' #14215
Is there an existing issue for this?
- I have searched the existing issues and did not find a match.
Who can help?
No response
What are you working on?
I fine-tuned a T5 model from Hugging Face and wanted to import it into Spark NLP.
Current Behavior
I followed the BERT import instructions in [https://github.com/JohnSnowLabs/spark-nlp/blob/master/examples/python/transformers/HuggingFace%20in%20Spark%20NLP%20-%20BERT.ipynb], but when I called loadSavedModel in Spark NLP it failed with "Py4JJavaError: An error occurred while calling z:com.johnsnowlabs.nlp.annotators.seq2seq.BartTransformer.loadSavedModel: java.lang.Exception: Could not retrieve the SavedModelBundle + ()". I then tried importing several other Hugging Face models, such as t5-base and BART, and got the same error.
Expected Behavior
I expect my fine-tuned T5 model to load successfully in Spark NLP. If possible, could someone also provide similar instructions for importing seq2seq models such as T5 and BART from Hugging Face into Spark NLP?
Steps To Reproduce
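import sparknlp
import tensorflow as tf
from sparknlp.annotator import T5Transformer
from transformers import T5Tokenizer, TFT5ForConditionalGeneration

MODEL_NAME = 'google-t5/t5-small'

# Save the tokenizer so spiece.model can be copied into the SavedModel assets below
tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)
tokenizer.save_pretrained('./{}_tokenizer/'.format(MODEL_NAME))

model = TFT5ForConditionalGeneration.from_pretrained(MODEL_NAME)

# Serving signature: token IDs and attention mask in, generated token IDs out
@tf.function(
    input_signature=(
        tf.TensorSpec(name="input_ids", shape=(None, None), dtype=tf.int32),
        tf.TensorSpec(name="attention_mask", shape=(None, None), dtype=tf.int32),
    )
)
def serving_fn(input_ids, attention_mask):
    outputs = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        return_dict_in_generate=True,  # required so that outputs["sequences"] exists
        # length_penalty=0.9,
        # repetition_penalty=2.0,
        # num_beams=4,
        # early_stopping=True,
    )
    return {"sequences": outputs["sequences"]}

# Writes the TF SavedModel to ./google-t5/t5-small/saved_model/1
model.save_pretrained("./{}".format(MODEL_NAME), saved_model=True, signatures={'serving_default': serving_fn})

asset_path = '{}/saved_model/1/assets'.format(MODEL_NAME)
!cp {MODEL_NAME}_tokenizer/spiece.model {asset_path}

# Write id2label, ordered by id, one label per line
labels = model.config.id2label
labels = [value for key, value in sorted(labels.items(), reverse=False)]
with open(asset_path + '/labels.txt', 'w') as f:
    f.write('\n'.join(labels))

spark = sparknlp.start()
T5 = T5Transformer.loadSavedModel('{}/saved_model/1'.format(MODEL_NAME), spark)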
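Before digging into the Spark NLP side, it may help to rule out a path or layout problem with the export. The sketch below is a stdlib-only check with hypothetical helper names (`expected_export_dir`, `missing_export_files`). It resolves the folder the script passes to `loadSavedModel` and lists any expected entries that are absent. The expected entries are an assumption drawn from the steps above: `save_pretrained(..., saved_model=True)` writes `saved_model.pb` and `variables/`, and the script copies `spiece.model` into `assets/`. Because `MODEL_NAME` contains a slash, the export lands two directories deep. Also, if `assets/` did not exist when `cp` ran, `cp` would have created a file named `assets` instead of copying into a folder, and this check would then report `assets/spiece.model` as missing. A clean result does not prove that Spark NLP can load the model; it only narrows down the cause.

```python
from __future__ import annotations

from pathlib import Path


def expected_export_dir(model_name: str, version: int = 1) -> Path:
    """Absolute path of the folder the repro passes to loadSavedModel.

    A Hub ID such as 'google-t5/t5-small' contains a slash, so the export is
    nested: <cwd>/google-t5/t5-small/saved_model/1.
    """
    return (Path(model_name) / "saved_model" / str(version)).resolve()


def missing_export_files(export_dir: Path) -> list[str]:
    """Return the expected entries that do not exist under export_dir.

    Assumed layout (taken from the repro steps, not from Spark NLP docs):
    saved_model.pb and variables/ from save_pretrained, plus
    assets/spiece.model copied in by the script.
    """
    expected = {
        "saved_model.pb": export_dir / "saved_model.pb",
        "variables/": export_dir / "variables",
        "assets/spiece.model": export_dir / "assets" / "spiece.model",
    }
    return [name for name, path in expected.items() if not path.exists()]


if __name__ == "__main__":
    export_dir = expected_export_dir("google-t5/t5-small")
    print("export dir:", export_dir)
    print("missing:", missing_export_files(export_dir) or "nothing")
```

If the check lists missing entries, fixing the export (or passing the printed absolute path to `loadSavedModel`) is a cheap thing to try before looking for a deeper incompatibility.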
Spark NLP version and Apache Spark
Spark NLP version: 5.3.2
Apache Spark version: 3.5.1
TensorFlow version: 2.15.0
Transformers version: 4.39.1
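The version list above can also be collected programmatically, which avoids typos when filing a report. This is a minimal sketch using only `importlib.metadata` (Python 3.8+). The helper name `installed_versions` is hypothetical, and it reports the installed pip distributions, which may differ from what a running session has loaded. Inside the session, `spark.version` and `sparknlp.version()` report the Spark and Spark NLP versions actually in use. The Java version (left blank below) is not covered; `java -version` in a shell reports it.

```python
from __future__ import annotations

from importlib.metadata import PackageNotFoundError, version

# PyPI distribution names for the components listed in this report
DISTRIBUTIONS = ("spark-nlp", "pyspark", "tensorflow", "transformers")


def installed_versions(names=DISTRIBUTIONS) -> dict[str, str | None]:
    """Map each distribution name to its installed version, or None if absent."""
    found: dict[str, str | None] = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver or 'not installed'}")
```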
Type of Spark Application
Python Application
Java Version
No response
Java Home Directory
No response
Setup and installation
Operating System and Version
Link to your project (if available)
No response
Additional Information
No response