Skip to content

Commit 22daf30

Browse files
added the capability to add tables
Signed-off-by: Peter Staar <taa@zurich.ibm.com>
1 parent 96a44b0 commit 22daf30

File tree

1 file changed

+195
-26
lines changed

1 file changed

+195
-26
lines changed

examples/smolagents/agents.py

Lines changed: 195 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,12 @@
2828
DoclingDocument,
2929
GroupItem,
3030
TitleItem,
31+
TableItem,
3132
SectionHeaderItem,
3233
TextItem,
3334
ListItem,
3435
LevelNumber,
36+
DocItemLabel,
3537
)
3638

3739

@@ -76,6 +78,7 @@ class BaseDoclingAgent(BaseModel):
7678
model: Model
7779
tools: list[Tool]
7880
chat_history: list[ChatMessage]
81+
max_iteration: int = 16
7982

8083
class Config:
8184
arbitrary_types_allowed = True # Needed for complex types like Model
@@ -141,18 +144,22 @@ class DoclingWritingAgent(BaseDoclingAgent):
141144
142145
...
143146
144-
## References
147+
## <final section header>
145148
146149
list: <1 sentence summary of what the list enumerates>
147150
```
148151
149-
Make sure that the Markdown outline is always enclosed in ```markdown <markdown-content> ```!
152+
Make sure that the Markdown outline is always enclosed in ```markdown <markdown-content>```!
150153
"""
151154

152155
system_prompt_expert_writer: ClassVar[
153156
str
154-
] = """You are an expert writer that needs to write a single paragraph, table
155-
or nested list based on a summary. Really stick to the summary and be specific, but do not write on adjacent topics
157+
] = """You are an expert writer that needs to write a single paragraph, table or nested list based on a summary. Really stick to the summary and be specific, but do not write on adjacent topics.
158+
"""
159+
160+
system_prompt_expert_table_writer: ClassVar[
161+
str
162+
] = """You are an expert writer that needs to write a single HTML table based on a summary. Really stick to the summary. Try to make interesting tables and leverage multi-column headers. If you have units in the table, make sure the units are in the column or row-headers of the table.
156163
"""
157164

158165
def __init__(self, *, model: Model, tools: list[Tool]):
@@ -182,7 +189,7 @@ def run(self, task: str, **kwargs):
182189

183190
document: DoclingDocument = self._make_outline_for_writing(task=task)
184191

185-
document = self._populate_document_with_content(task=task, document=document)
192+
document = self._populate_document_with_content(task=task, outline=document)
186193

187194
print(document.export_to_markdown(text_width=72))
188195

@@ -244,37 +251,140 @@ def _make_outline_for_writing(self, *, task: str) -> DoclingDocument:
244251
),
245252
]
246253

247-
output = self.model.generate(messages=chat_messages)
254+
iteration = 0
255+
while iteration < self.max_iteration:
256+
iteration += 1
257+
logger.info(f"_make_outline_for_writing: iteration {iteration}")
258+
259+
output = self.model.generate(messages=chat_messages)
260+
261+
results = self._analyse_output_into_docling_document(message=output)
262+
263+
if len(results) == 0:
264+
chat_messages.append(
265+
ChatMessage(
266+
role=MessageRole.USER,
267+
content=[
268+
{
269+
"type": "text",
270+
"text": f"I see now markdown section. Please try again and add a markdown section in the format ```markdown <insert-content>``` for task: {task}!",
271+
}
272+
],
273+
)
274+
)
275+
continue
276+
elif len(results) > 1:
277+
chat_messages.append(
278+
ChatMessage(
279+
role=MessageRole.USER,
280+
content=[
281+
{
282+
"type": "text",
283+
"text": f"I see multiple markdown sections. Please try again and only add a single markdown section in the format ```markdown <insert-content>``` for task: {task}!",
284+
}
285+
],
286+
)
287+
)
288+
continue
289+
else:
290+
logger.info("We obtained a markdown for the outline!")
291+
292+
document = results[0]
293+
logger.info(f"outline: {document.export_to_markdown()}")
294+
295+
starts = [
296+
"paragraph: ",
297+
"table: ",
298+
"picture: ",
299+
"list: ",
300+
]
301+
lines = []
302+
for item, level in document.iterate_items(with_groups=True):
303+
if isinstance(item, TitleItem) or isinstance(item, SectionHeaderItem):
304+
continue
305+
elif isinstance(item, TextItem):
306+
good: bool = False
307+
for start in starts:
308+
if item.text.startswith(start):
309+
good = True
310+
break
311+
312+
if not good:
313+
lines.append(item.text)
314+
315+
logger.info(f"broken lines: {'\n'.join(lines)}")
316+
317+
if len(lines) > 0:
318+
message = f"Every content line should start with one out of the following choices: {starts}. The following lines need to be updated: {'\n'.join(lines)}"
319+
chat_messages.append(
320+
ChatMessage(
321+
role=MessageRole.USER,
322+
content=[{"type": "text", "text": message}],
323+
)
324+
)
325+
else:
326+
self.chat_history.extend(chat_messages)
327+
self.chat_history.append(output)
248328

249-
self.chat_history.extend(chat_messages)
250-
self.chat_history.append(output)
329+
logger.info(
330+
f"Finished an outline for document: {document.export_to_markdown()}"
331+
)
332+
return document
251333

252-
results = self._analyse_output_into_docling_document(message=output)
253-
assert len(results) == 1, (
254-
"We only want to see a single response from the initial task analysis"
255-
)
334+
raise ValueError("Could not make a correct outline!")
256335

257-
document = results[0]
258-
return document
336+
def _populate_document_with_content(
    self, *, task: str, outline: DoclingDocument
) -> DoclingDocument:
    """Build a new document by expanding every placeholder line of *outline*.

    The outline is walked in order. Titles and section headers are copied
    into the new document; text items starting with ``paragraph:``,
    ``list:`` or ``table:`` are treated as one-line summaries of the
    content to generate.

    Args:
        task: The user task; only used to name the resulting document.
        outline: The outline document whose items are iterated.

    Returns:
        A freshly created ``DoclingDocument`` holding the generated content.
    """
    # Most recent header text per heading level (0 is the title).
    # NOTE(review): maintained here but not consumed within this method.
    headers: dict[int, str] = {}

    document = DoclingDocument(name=f"report on task: `{task}`")

    for item, _level in outline.iterate_items(with_groups=True):
        if isinstance(item, TitleItem):
            headers[0] = item.text
            document.add_title(text=item.text)

        elif isinstance(item, SectionHeaderItem):
            logger.info(f"starting in {item.text}")
            headers[item.level] = item.text

            document.add_heading(text=item.text, level=item.level)

            # Entering a header at this level invalidates all deeper ones.
            # Snapshot the stale keys first: a dict must not change size
            # while it is being iterated.
            for key in [k for k in headers if k > item.level]:
                del headers[key]

        elif isinstance(item, TextItem):
            if item.text.startswith("paragraph:"):
                # Strip only the leading marker; the same word inside the
                # summary itself must be left untouched.
                summary = item.text.removeprefix("paragraph:").strip()

                logger.info(f"need to write a paragraph: {summary}")
                content = self._write_paragraph(
                    summary=summary, item_type="paragraph"
                )
                document.add_text(label=DocItemLabel.TEXT, text=content)

            elif item.text.startswith("list:"):
                summary = item.text.removeprefix("list:").strip()
                logger.info(f"need to write a list: {summary}")

                # No list writer is invoked here: the outline line is
                # copied through unchanged as a placeholder.
                document.add_text(label=DocItemLabel.TEXT, text=item.text)

            elif item.text.startswith("table:"):
                summary = item.text.removeprefix("table:").strip()
                logger.info(f"need to write a table: {summary}")

                table_item = self._write_table(summary=summary)

                if table_item is None:
                    # _write_table gives up with None after its retries;
                    # keep the outline line so the section is not lost.
                    logger.warning(f"could not write a table for: {summary}")
                    document.add_text(label=DocItemLabel.TEXT, text=item.text)
                    continue

                caption = document.add_text(
                    label=DocItemLabel.CAPTION, text=summary
                )
                document.add_table(data=table_item.data, caption=caption)

    return document
279389

280390
def _analyse_output_into_docling_document(
@@ -285,25 +395,18 @@ def extract_code_blocks(text, language: str):
285395
matches = re.findall(pattern, text, re.DOTALL)
286396
return matches
287397

288-
print(
289-
f"content: \n\n--------------------\n{message.content}\n--------------------\n"
290-
)
291-
292398
converter = DocumentConverter(allowed_formats=[InputFormat.MD])
293399

294400
result = []
295401
for mtch in extract_code_blocks(message.content, language=language):
296402
md_doc: str = mtch
297-
print("md-doc:\n\n", md_doc)
298403

299404
buff = BytesIO(md_doc.encode("utf-8"))
300405
doc_stream = DocumentStream(name="tmp.md", stream=buff)
301406

302407
conv_result: ConversionResult = converter.convert(doc_stream)
303408
result.append(conv_result.document)
304409

305-
logger.warning(f"#-results: {len(result)}")
306-
307410
return result
308411

309412
def _write_paragraph(
@@ -333,6 +436,72 @@ def _write_paragraph(
333436
output = self.model.generate(messages=chat_messages)
334437
return output.content
335438

439+
def _write_table(
    self, summary: str, hierarchy: list[str] | None = None
) -> TableItem | None:
    """Ask the model for an HTML table and return it as a ``TableItem``.

    The model is prompted up to ``self.max_iteration`` times. Each reply is
    searched for HTML, converted with a ``DocumentConverter`` restricted to
    HTML input, and the first ``TableItem`` of the converted document is
    returned. After a failed attempt a user message describing the problem
    is appended to the conversation before the next attempt.

    Args:
        summary: One-line description of what the table should contain.
        hierarchy: Currently unused by this method.

    Returns:
        The first table found in a converted reply, or ``None`` when no
        attempt produced one.
    """

    def extract_code_blocks(text: str) -> str | None:
        # Prefer a fenced ```html block, then fall back to the inside of a
        # bare <html>...</html> element.
        for pattern in (r"```html(.*?)```", r"<html>(.*?)</html>"):
            matches = re.findall(pattern, text, re.DOTALL)
            if matches:
                return matches[0]
        return None

    def user_message(text: str) -> ChatMessage:
        # Same message shape as the prompts built below.
        return ChatMessage(
            role=MessageRole.USER,
            content=[{"type": "text", "text": text}],
        )

    chat_messages = [
        ChatMessage(
            role=MessageRole.SYSTEM,
            content=[
                {
                    "type": "text",
                    "text": self.system_prompt_expert_table_writer,
                }
            ],
        ),
        user_message(
            f"write me a single table in HTML that expands the following summary: {summary}"
        ),
    ]

    converter = DocumentConverter(allowed_formats=[InputFormat.HTML])

    for iteration in range(1, self.max_iteration + 1):
        logger.info(f"_write_table: iteration {iteration}")

        output = self.model.generate(messages=chat_messages)
        logger.debug(f"output:\n\n{output.content}")

        # A reply without content is treated like one without any HTML.
        html_doc = extract_code_blocks(output.content or "")
        logger.debug(f"html-doc:\n\n{html_doc}")

        if html_doc is None:
            chat_messages.append(
                user_message(
                    "I see no HTML section. Please try again and add a single "
                    f"table in the format ```html <insert-content>``` for summary: {summary}!"
                )
            )
            continue

        try:
            buff = BytesIO(html_doc.encode("utf-8"))
            doc_stream = DocumentStream(name="tmp.html", stream=buff)

            conv_result: ConversionResult = converter.convert(doc_stream)
            doc = conv_result.document
        except Exception as exc:
            # Deliberately broad: any conversion failure only costs one
            # attempt and is reported back to the model.
            logger.error(f"error with html conversion: {exc}")
            chat_messages.append(
                user_message(
                    f"The HTML could not be converted ({exc}). Please try again "
                    "and return a single valid table in the format "
                    "```html <insert-content>```!"
                )
            )
            continue

        if doc:
            for item, _level in doc.iterate_items(with_groups=True):
                if isinstance(item, TableItem):
                    return item

        chat_messages.append(
            user_message(
                "I could not find a table in the HTML. Please try again and "
                "return a single <table> in the format ```html <insert-content>```!"
            )
        )

    return None
504+
336505

337506
def main():
338507
"""

0 commit comments

Comments
 (0)