Skip to content

Commit

Permalink
py: For multipart endpoint, don't split into batches of 20 MB (#1067)
Browse files Browse the repository at this point in the history
- The multipart endpoint supports batches of any size
- Add a Content-Length header to each part
  • Loading branch information
nfcampos authored Oct 7, 2024
2 parents da3c1bb + fac9500 commit d67469f
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 59 deletions.
56 changes: 27 additions & 29 deletions python/langsmith/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ class ZoneInfo: # type: ignore[no-redef]
WARNED_ATTACHMENTS = False
EMPTY_SEQ: tuple[Dict, ...] = ()
BOUNDARY = uuid.uuid4().hex
MultipartParts = List[Tuple[str, Tuple[None, bytes, str]]]
MultipartParts = List[Tuple[str, Tuple[None, bytes, str, Dict[str, str]]]]
URLLIB3_SUPPORTS_BLOCKSIZE = "key_blocksize" in signature(PoolKey).parameters


Expand Down Expand Up @@ -1638,63 +1638,61 @@ def multipart_ingest_runs(
# insert runtime environment
self._insert_runtime_env(create_dicts)
self._insert_runtime_env(update_dicts)
# check size limit
size_limit_bytes = (self.info.batch_ingest_config or {}).get(
"size_limit_bytes"
) or _SIZE_LIMIT_BYTES
# send the runs in multipart requests
acc_size = 0
acc_context: List[str] = []
acc_parts: MultipartParts = []
for event, payloads in (("post", create_dicts), ("patch", update_dicts)):
for payload in payloads:
parts: MultipartParts = []
# collect fields to be sent as separate parts
fields = [
("inputs", payload.pop("inputs", None)),
("outputs", payload.pop("outputs", None)),
("events", payload.pop("events", None)),
]
# encode the main run payload
parts.append(
payloadb = _dumps_json(payload)
acc_parts.append(
(
f"{event}.{payload['id']}",
(None, _dumps_json(payload), "application/json"),
(
None,
payloadb,
"application/json",
{"Content-Length": str(len(payloadb))},
),
)
)
# encode the fields we collected
for key, value in fields:
if value is None:
continue
parts.append(
valb = _dumps_json(value)
acc_parts.append(
(
f"{event}.{payload['id']}.{key}",
(None, _dumps_json(value), "application/json"),
(
None,
valb,
"application/json",
{"Content-Length": str(len(valb))},
),
),
)
# encode the attachments
if attachments := all_attachments.pop(payload["id"], None):
for n, (ct, ba) in attachments.items():
parts.append(
(f"attachment.{payload['id']}.{n}", (None, ba, ct))
acc_parts.append(
(
f"attachment.{payload['id']}.{n}",
(None, ba, ct, {"Content-Length": str(len(ba))}),
)
)
# calculate the size of the parts
size = sum(len(p[1][1]) for p in parts)
# compute context
context = f"trace={payload.get('trace_id')},id={payload.get('id')}"
# if next size would exceed limit, send the current parts
if acc_size + size > size_limit_bytes:
self._send_multipart_req(acc_parts, _context="; ".join(acc_context))
acc_parts.clear()
acc_context.clear()
acc_size = 0
# accumulate the parts
acc_size += size
acc_parts.extend(parts)
acc_context.append(context)
# send the remaining parts
if acc_parts:
self._send_multipart_req(acc_parts, _context="; ".join(acc_context))
acc_context.append(
f"trace={payload.get('trace_id')},id={payload.get('id')}"
)
# send the request
self._send_multipart_req(acc_parts, _context="; ".join(acc_context))

def _send_multipart_req(self, parts: MultipartParts, *, _context: str):
for api_url, api_key in self._write_api_urls.items():
Expand Down
41 changes: 21 additions & 20 deletions python/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ requests-toolbelt = "^1.0.0"
pytest = "^7.3.1"
black = ">=23.3,<25.0"
mypy = "^1.9.0"
ruff = "^0.3.4"
ruff = "^0.6.9"
types-requests = "^2.31.0.1"
pandas-stubs = "^2.0.1.230501"
types-pyyaml = "^6.0.12.10"
Expand Down
8 changes: 4 additions & 4 deletions python/tests/integration_tests/wrappers/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def test_chat_sync_api(mock_session: mock.MagicMock, stream: bool):
assert len(original_chunks) == len(patched_chunks)
assert [o.choices == p.choices for o, p in zip(original_chunks, patched_chunks)]
else:
assert type(original) == type(patched)
assert type(original) is type(patched)
assert original.choices == patched.choices
# Give the thread a chance.
time.sleep(0.01)
Expand Down Expand Up @@ -74,7 +74,7 @@ async def test_chat_async_api(mock_session: mock.MagicMock, stream: bool):
assert len(original_chunks) == len(patched_chunks)
assert [o.choices == p.choices for o, p in zip(original_chunks, patched_chunks)]
else:
assert type(original) == type(patched)
assert type(original) is type(patched)
assert original.choices == patched.choices
# Give the thread a chance.
time.sleep(0.1)
Expand Down Expand Up @@ -117,7 +117,7 @@ def test_completions_sync_api(mock_session: mock.MagicMock, stream: bool):
assert original.response
assert patched.response
else:
assert type(original) == type(patched)
assert type(original) is type(patched)
assert original.choices == patched.choices
# Give the thread a chance.
time.sleep(0.1)
Expand Down Expand Up @@ -170,7 +170,7 @@ async def test_completions_async_api(mock_session: mock.MagicMock, stream: bool)
assert original.response
assert patched.response
else:
assert type(original) == type(patched)
assert type(original) is type(patched)
assert original.choices == patched.choices
# Give the thread a chance.
for _ in range(10):
Expand Down
7 changes: 2 additions & 5 deletions python/tests/unit_tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1061,11 +1061,8 @@ def test_batch_ingest_run_splits_large_batches(
]
if use_multipart_endpoint:
client.multipart_ingest_runs(create=posts, update=patches)
# we can support up to 20MB per batch, so we need to find the number of batches
# we should be sending
max_in_batch = max(1, (20 * MB) // (payload_size + 20))

expected_num_requests = min(6, math.ceil((len(run_ids) * 2) / max_in_batch))
# multipart endpoint should only send one request
expected_num_requests = 1
# count the number of POST requests
assert sum(
[1 for call in mock_session.request.call_args_list if call[0][0] == "POST"]
Expand Down

0 comments on commit d67469f

Please sign in to comment.