|
19 | 19 | import functools |
20 | 20 | import json |
21 | 21 | import logging |
| 22 | +import mimetypes |
22 | 23 | from typing import Any, Iterator, Optional, Union |
23 | 24 | from urllib.parse import urlencode |
24 | 25 |
|
@@ -528,6 +529,22 @@ def _get_sandbox_operation( |
528 | 529 | self._api_client._verify_response(return_value) |
529 | 530 | return return_value |
530 | 531 |
|
| 532 | + _NEEDED_BASE64_ENCODING_MIME_TYPES = [ |
| 533 | + "image/jpeg", |
| 534 | + "image/png", |
| 535 | + "image/gif", |
| 536 | + "image/bmp", |
| 537 | + "image/webp", |
| 538 | + "audio/mpeg", |
| 539 | + "audio/wav", |
| 540 | + "audio/aac", |
| 541 | + "audio/ogg", |
| 542 | + "video/mp4", |
| 543 | + "video/webm", |
| 544 | + "video/ogg", |
| 545 | + "video/mpeg", |
| 546 | + ] |
| 547 | + |
531 | 548 | def create( |
532 | 549 | self, |
533 | 550 | *, |
@@ -618,20 +635,66 @@ def execute_code( |
618 | 635 | Returns: |
619 | 636 | ExecuteSandboxEnvironmentResponse: The response from executing the code. |
620 | 637 | """ |
621 | | - json_string = json.dumps(input_data) |
622 | | - |
623 | | - base64_bytes = base64.b64encode(json_string.encode("utf-8")) |
624 | | - base64_string = base64_bytes.decode("utf-8") |
| 638 | + input_chunks = [] |
| 639 | + |
| 640 | + if input_data.get("code") is not None: |
| 641 | + code = input_data.get("code", "") |
| 642 | + json_code = json.dumps({"code": code}).encode("utf-8") |
| 643 | + input_chunks.append( |
| 644 | + types.Chunk( |
| 645 | + mime_type="application/json", |
| 646 | + data=json_code, |
| 647 | + ) |
| 648 | + ) |
625 | 649 |
|
626 | | - # Only single JSON input is supported for now. |
627 | | - inputs = [{"mime_type": "application/json", "data": base64_string}] |
| 650 | + for file in input_data.get("files", []): |
| 651 | + file_name = file.get("name", "") |
| 652 | + mime_type = file.get("mimeType", "") |
| 653 | + if mime_type is None: |
| 654 | + mime_type, _ = mimetypes.guess_type(file_name) |
| 655 | + if mime_type in self._NEEDED_BASE64_ENCODING_MIME_TYPES: |
| 656 | + base64_bytes = base64.b64encode(file.get("content", b"")) |
| 657 | + content = base64_bytes.decode("utf-8") |
| 658 | + else: |
| 659 | + content = file.get("content", b"") |
| 660 | + input_chunks.append( |
| 661 | + types.Chunk( |
| 662 | + mime_type=mime_type, |
| 663 | + data=content, |
| 664 | + metadata={"attributes": {"file_name": file_name.encode("utf-8")}}, |
| 665 | + ) |
| 666 | + ) |
628 | 667 |
|
629 | 668 | response = self._execute_code( |
630 | 669 | name=name, |
631 | | - inputs=inputs, |
| 670 | + inputs=input_chunks, |
632 | 671 | config=config, |
633 | 672 | ) |
634 | 673 |
|
| 674 | + output_chunks = [] |
| 675 | + for output in response.outputs: |
| 676 | + if output.mime_type != "application/json": |
| 677 | + mime_type = output.mime_type |
| 678 | + # if mime_type is not available, try to guess the mime_type from the file_name. |
| 679 | + if ( |
| 680 | + mime_type is None |
| 681 | + and output.metadata is not None |
| 682 | + and output.metadata.attributes is not None |
| 683 | + ): |
| 684 | + file_name = output.metadata.attributes.get("file_name", b"").decode( |
| 685 | + "utf-8" |
| 686 | + ) |
| 687 | + mime_type, _ = mimetypes.guess_type(file_name) |
| 688 | + output.mime_type = mime_type |
| 689 | + |
| 690 | + # if the mime_type is in the list of mime_types that need base64 encoding, |
| 691 | + # decode the data. |
| 692 | + if mime_type in self._NEEDED_BASE64_ENCODING_MIME_TYPES: |
| 693 | + output.data = base64.b64decode(output.data) |
| 694 | + output_chunks.append(output) |
| 695 | + |
| 696 | + response = types.ExecuteSandboxEnvironmentResponse(outputs=output_chunks) |
| 697 | + |
635 | 698 | return response |
636 | 699 |
|
637 | 700 | def get( |
|
0 commit comments