@@ -33,6 +33,7 @@ class ConversationTemplate(BaseModel):
33
33
bos_tokens_ : List [int ]
34
34
eot_tokens_ : List [int ]
35
35
message_prefix_tokens_ : List [int ]
36
+ system_role_tokens_ : Optional [List [int ]] = []
36
37
37
38
def __init__ (self , ** data ):
38
39
tokenizer = data ["tokenizer" ]
@@ -63,15 +64,14 @@ def tokenize_conversations(self, conversations: Iterable[Conversation], inferenc
63
64
for msg in conv .items :
64
65
role_mappings .add ((msg .role , conv .condition or default_condition ))
65
66
all_text .append (msg .content )
66
-
67
+
68
+ if self .system_as_role :
69
+ self .system_role_tokens_ = self .tokenizer (self .role_prefix ("system" , "" ), add_special_tokens = False ).input_ids + self .message_prefix_tokens_
70
+
67
71
sys_mappings = list (sys_mappings )
68
72
role_mappings = list (role_mappings )
69
73
70
- # Tokenize
71
- if self .system_as_role :
72
- sys_mappings = dict (zip (sys_mappings , self ._tokenize ([self .role_prefix (sys ) for sys in sys_mappings ], ignore_special = False )))
73
- else :
74
- sys_mappings = dict (zip (sys_mappings , self ._tokenize (sys_mappings )))
74
+ sys_mappings = dict (zip (sys_mappings , self ._tokenize (sys_mappings )))
75
75
role_mappings = dict (zip (role_mappings , self ._tokenize ([self .role_prefix (* args ) for args in role_mappings ], ignore_special = False )))
76
76
all_text = self ._tokenize (all_text )
77
77
@@ -89,6 +89,9 @@ def tokenize_conversations(self, conversations: Iterable[Conversation], inferenc
89
89
90
90
# System
91
91
if conv .system :
92
+ tokens .extend (self .system_role_tokens_ )
93
+ weights .extend ([0. ] * len (self .system_role_tokens_ ))
94
+
92
95
system = sys_mappings [conv .system ]
93
96
tokens .extend (system )
94
97
weights .extend ([0. ] * len (system ))
0 commit comments