Skip to content

Commit f4fd4e7

Browse files
authored
SPARKNLP Introducing LLAMA 3 (#14379)
* LLAMA3 Scala API * LLAMA3 Python API * LLAMA3 Notebook * LLAMA3 doc update
1 parent 9285df8 commit f4fd4e7

File tree

11 files changed

+4341
-3
lines changed

11 files changed

+4341
-3
lines changed

examples/python/transformers/openvino/HuggingFace_OpenVINO_in_Spark_NLP_LLama3.ipynb

Lines changed: 2657 additions & 0 deletions
Large diffs are not rendered by default.

python/sparknlp/annotator/seq2seq/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,5 @@
2525
from sparknlp.annotator.seq2seq.nllb_transformer import *
2626
from sparknlp.annotator.seq2seq.cpm_transformer import *
2727
from sparknlp.annotator.seq2seq.qwen_transformer import *
28-
from sparknlp.annotator.seq2seq.starcoder_transformer import *
28+
from sparknlp.annotator.seq2seq.starcoder_transformer import *
29+
from sparknlp.annotator.seq2seq.llama3_transformer import *
Lines changed: 381 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,381 @@
1+
# Copyright 2017-2022 John Snow Labs
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Contains classes for the LLAMA3Transformer."""
15+
16+
from sparknlp.common import *
17+
18+
19+
class LLAMA3Transformer(AnnotatorModel, HasBatchedAnnotate, HasEngine):
    """Llama 3: Cutting-Edge Foundation and Fine-Tuned Chat Models

    The Llama 3 release introduces a new family of pretrained and fine-tuned LLMs,
    ranging in scale from 8B to 70B parameters. Llama 3 models are designed with
    enhanced efficiency, performance, and safety, making them more capable than
    previous versions. These models are trained on a more diverse and expansive
    dataset, offering improved contextual understanding and generation quality.

    The fine-tuned models, known as Llama 3-instruct, are optimized for dialogue
    applications using an advanced version of Reinforcement Learning from Human
    Feedback (RLHF). Llama 3-instruct models demonstrate superior performance
    across multiple benchmarks, outperforming Llama 2 and even matching or
    exceeding the capabilities of some closed-source models.

    Pretrained models can be loaded with :meth:`.pretrained` of the companion
    object:

    >>> llama3 = LLAMA3Transformer.pretrained() \\
    ...     .setInputCols(["document"]) \\
    ...     .setOutputCol("generation")

    The default model is ``"llama_3_7b_chat_hf_int4"``, if no name is provided.
    For available pretrained models please see the `Models Hub
    <https://sparknlp.org/models?q=llama3>`__.

    ====================== ======================
    Input Annotation types Output Annotation type
    ====================== ======================
    ``DOCUMENT``           ``DOCUMENT``
    ====================== ======================

    Parameters
    ----------
    configProtoBytes
        ConfigProto from tensorflow, serialized into byte array.
    minOutputLength
        Minimum length of the sequence to be generated, by default 0
    maxOutputLength
        Maximum length of output text, by default 20
    doSample
        Whether or not to use sampling; use greedy decoding otherwise, by default False
    temperature
        The value used to modulate the next token probabilities, by default 0.6
    topK
        The number of highest probability vocabulary tokens to keep for
        top-k-filtering, by default -1 (disabled)
    topP
        Top cumulative probability for vocabulary tokens, by default 0.9

        If set to float < 1, only the most probable tokens with probabilities
        that add up to ``topP`` or higher are kept for generation.
    repetitionPenalty
        The parameter for repetition penalty, 1.0 means no penalty, by default
        1.0
    noRepeatNgramSize
        If set to int > 0, all ngrams of that size can only occur once, by
        default 3
    ignoreTokenIds
        A list of token ids which are ignored in the decoder's output, by
        default []
    beamSize
        The number of beams to use for beam search, by default 1
    stopTokenIds
        A list of token ids which are considered as stop tokens in the
        decoder's output, by default [128001]

    Notes
    -----
    This is a very computationally expensive module, especially on larger
    sequences. The use of an accelerator such as GPU is recommended.

    References
    ----------
    - `Llama 3: Cutting-Edge Foundation and Fine-Tuned Chat Models
      <https://ai.meta.com/blog/meta-llama-3/>`__
    - https://github.com/facebookresearch/llama

    **Paper Abstract:**

    *Llama 3 is the latest iteration of large language models from Meta, offering a range of models
    from 1 billion to 70 billion parameters. The fine-tuned versions, known as Llama 3-Chat, are
    specifically designed for dialogue applications and have been optimized using advanced techniques
    such as RLHF. Llama 3 models show remarkable improvements in both safety and performance, making
    them a leading choice in both open-source and commercial settings. Our comprehensive approach to
    training and fine-tuning these models is aimed at encouraging responsible AI development and fostering
    community collaboration.*

    Examples
    --------
    >>> import sparknlp
    >>> from sparknlp.base import *
    >>> from sparknlp.annotator import *
    >>> from pyspark.ml import Pipeline
    >>> documentAssembler = DocumentAssembler() \\
    ...     .setInputCol("text") \\
    ...     .setOutputCol("documents")
    >>> llama3 = LLAMA3Transformer.pretrained("llama_3_7b_chat_hf_int8") \\
    ...     .setInputCols(["documents"]) \\
    ...     .setMaxOutputLength(60) \\
    ...     .setOutputCol("generation")
    >>> pipeline = Pipeline().setStages([documentAssembler, llama3])
    >>> data = spark.createDataFrame([
    ...     (
    ...         1,
    ...         "<|start_header_id|>system<|end_header_id|> \\n"+ \
    ...         "You are a minion chatbot who always responds in minion speak! \\n" + \
    ...         "<|start_header_id|>user<|end_header_id|> \\n" + \
    ...         "Who are you? \\n" + \
    ...         "<|start_header_id|>assistant<|end_header_id|> \\n"
    ...     )
    ... ]).toDF("id", "text")
    >>> result = pipeline.fit(data).transform(data)
    >>> result.select("generation.result").show(truncate=False)
    +------------------------------------------------+
    |result                                          |
    +------------------------------------------------+
    |[Oooh, me am Minion! Me help you with things! Me speak Minion language, yeah! Bana-na-na!]|
    +------------------------------------------------+
    """

    name = "LLAMA3Transformer"

    inputAnnotatorTypes = [AnnotatorType.DOCUMENT]

    outputAnnotatorType = AnnotatorType.DOCUMENT

    configProtoBytes = Param(Params._dummy(),
                             "configProtoBytes",
                             "ConfigProto from tensorflow, serialized into byte array. Get with config_proto.SerializeToString()",
                             TypeConverters.toListInt)

    minOutputLength = Param(Params._dummy(), "minOutputLength", "Minimum length of the sequence to be generated",
                            typeConverter=TypeConverters.toInt)

    maxOutputLength = Param(Params._dummy(), "maxOutputLength", "Maximum length of output text",
                            typeConverter=TypeConverters.toInt)

    doSample = Param(Params._dummy(), "doSample", "Whether or not to use sampling; use greedy decoding otherwise",
                     typeConverter=TypeConverters.toBoolean)

    temperature = Param(Params._dummy(), "temperature", "The value used to module the next token probabilities",
                        typeConverter=TypeConverters.toFloat)

    topK = Param(Params._dummy(), "topK",
                 "The number of highest probability vocabulary tokens to keep for top-k-filtering",
                 typeConverter=TypeConverters.toInt)

    topP = Param(Params._dummy(), "topP",
                 "If set to float < 1, only the most probable tokens with probabilities that add up to ``top_p`` or higher are kept for generation",
                 typeConverter=TypeConverters.toFloat)

    repetitionPenalty = Param(Params._dummy(), "repetitionPenalty",
                              "The parameter for repetition penalty. 1.0 means no penalty. See `this paper <https://arxiv.org/pdf/1909.05858.pdf>`__ for more details",
                              typeConverter=TypeConverters.toFloat)

    noRepeatNgramSize = Param(Params._dummy(), "noRepeatNgramSize",
                              "If set to int > 0, all ngrams of that size can only occur once",
                              typeConverter=TypeConverters.toInt)

    ignoreTokenIds = Param(Params._dummy(), "ignoreTokenIds",
                           "A list of token ids which are ignored in the decoder's output",
                           typeConverter=TypeConverters.toListInt)

    beamSize = Param(Params._dummy(), "beamSize",
                     "The number of beams to use for beam search",
                     typeConverter=TypeConverters.toInt)

    stopTokenIds = Param(Params._dummy(), "stopTokenIds",
                         "A list of token ids which are considered as stop tokens in the decoder's output",
                         typeConverter=TypeConverters.toListInt)

    def setIgnoreTokenIds(self, value):
        """A list of token ids which are ignored in the decoder's output.

        Parameters
        ----------
        value : List[int]
            The words to be filtered out
        """
        return self._set(ignoreTokenIds=value)

    def setConfigProtoBytes(self, b):
        """Sets configProto from tensorflow, serialized into byte array.

        Parameters
        ----------
        b : List[int]
            ConfigProto from tensorflow, serialized into byte array
        """
        return self._set(configProtoBytes=b)

    def setMinOutputLength(self, value):
        """Sets minimum length of the sequence to be generated.

        Parameters
        ----------
        value : int
            Minimum length of the sequence to be generated
        """
        return self._set(minOutputLength=value)

    def setMaxOutputLength(self, value):
        """Sets maximum length of output text.

        Parameters
        ----------
        value : int
            Maximum length of output text
        """
        return self._set(maxOutputLength=value)

    def setDoSample(self, value):
        """Sets whether or not to use sampling, use greedy decoding otherwise.

        Parameters
        ----------
        value : bool
            Whether or not to use sampling; use greedy decoding otherwise
        """
        return self._set(doSample=value)

    def setTemperature(self, value):
        """Sets the value used to module the next token probabilities.

        Parameters
        ----------
        value : float
            The value used to module the next token probabilities
        """
        return self._set(temperature=value)

    def setTopK(self, value):
        """Sets the number of highest probability vocabulary tokens to keep for
        top-k-filtering.

        Parameters
        ----------
        value : int
            Number of highest probability vocabulary tokens to keep
        """
        return self._set(topK=value)

    def setTopP(self, value):
        """Sets the top cumulative probability for vocabulary tokens.

        If set to float < 1, only the most probable tokens with probabilities
        that add up to ``topP`` or higher are kept for generation.

        Parameters
        ----------
        value : float
            Cumulative probability for vocabulary tokens
        """
        return self._set(topP=value)

    def setRepetitionPenalty(self, value):
        """Sets the parameter for repetition penalty. 1.0 means no penalty.

        Parameters
        ----------
        value : float
            The repetition penalty

        References
        ----------
        See `Ctrl: A Conditional Transformer Language Model For Controllable
        Generation <https://arxiv.org/pdf/1909.05858.pdf>`__ for more details.
        """
        return self._set(repetitionPenalty=value)

    def setNoRepeatNgramSize(self, value):
        """Sets size of n-grams that can only occur once.

        If set to int > 0, all ngrams of that size can only occur once.

        Parameters
        ----------
        value : int
            N-gram size can only occur once
        """
        return self._set(noRepeatNgramSize=value)

    def setBeamSize(self, value):
        """Sets the number of beams to use for beam search.

        Parameters
        ----------
        value : int
            The number of beams to use for beam search
        """
        return self._set(beamSize=value)

    def setStopTokenIds(self, value):
        """Sets a list of token ids which are considered as stop tokens in the decoder's output.

        Parameters
        ----------
        value : List[int]
            The words to be considered as stop tokens
        """
        return self._set(stopTokenIds=value)

    @keyword_only
    def __init__(self, classname="com.johnsnowlabs.nlp.annotators.seq2seq.LLAMA3Transformer", java_model=None):
        super(LLAMA3Transformer, self).__init__(
            classname=classname,
            java_model=java_model
        )
        # Defaults mirror the Scala-side annotator; 128001 is the Llama 3
        # <|end_of_text|> token id, used as the stop token.
        self._setDefault(
            minOutputLength=0,
            maxOutputLength=20,
            doSample=False,
            temperature=0.6,
            topK=-1,
            topP=0.9,
            repetitionPenalty=1.0,
            noRepeatNgramSize=3,
            ignoreTokenIds=[],
            batchSize=1,
            beamSize=1,
            stopTokenIds=[128001]
        )

    @staticmethod
    def loadSavedModel(folder, spark_session, use_openvino=False):
        """Loads a locally saved model.

        Parameters
        ----------
        folder : str
            Folder of the saved model
        spark_session : pyspark.sql.SparkSession
            The current SparkSession
        use_openvino : bool, optional
            Whether to load the model with the OpenVINO backend, by default
            False

        Returns
        -------
        LLAMA3Transformer
            The restored model
        """
        from sparknlp.internal import _LLAMA3Loader
        jModel = _LLAMA3Loader(folder, spark_session._jsparkSession, use_openvino)._java_obj
        return LLAMA3Transformer(java_model=jModel)

    @staticmethod
    def pretrained(name="llama_3_7b_chat_hf_int4", lang="en", remote_loc=None):
        """Downloads and loads a pretrained model.

        Parameters
        ----------
        name : str, optional
            Name of the pretrained model, by default "llama_3_7b_chat_hf_int4"
        lang : str, optional
            Language of the pretrained model, by default "en"
        remote_loc : str, optional
            Optional remote address of the resource, by default None. Will use
            Spark NLPs repositories otherwise.

        Returns
        -------
        LLAMA3Transformer
            The restored model
        """
        from sparknlp.pretrained import ResourceDownloader
        return ResourceDownloader.downloadModel(LLAMA3Transformer, name, lang, remote_loc)

0 commit comments

Comments
 (0)