add organization_id param to LlamaCloudIndex.from_documents (#14947)
* add organization_id param to from_documents

* update version

* add org id param to LlamaCloudIndex ctor

* add back org id var
sourabhdesai authored Jul 24, 2024
1 parent d94e0b0 commit ad2b0cd
Showing 3 changed files with 40 additions and 7 deletions.
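
To illustrate what this enables, here is a minimal, hypothetical usage sketch of the changed classmethod. Only the parameter names (name, project_name, organization_id, api_key) and the import paths are taken from the diff and test file below; the documents positional argument, the pipeline name, and all ID/key values are placeholders.

from llama_index.core.schema import Document
from llama_index.indices.managed.llama_cloud import LlamaCloudIndex

documents = [Document(text="Hello world.", doc_id="1")]

# organization_id is the keyword added by this commit; from_documents forwards it
# to projects.upsert_project so the managed project is created or fetched inside
# that organization instead of being resolved by project_name alone.
index = LlamaCloudIndex.from_documents(
    documents,
    name="my managed pipeline",           # placeholder pipeline name
    project_name="Default",               # placeholder project name
    organization_id="<organization-id>",  # placeholder value for the new parameter
    api_key="<llama-cloud-api-key>",      # placeholder API key
)
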
@@ -60,6 +60,7 @@ def __init__(
         transformations: Optional[List[TransformComponent]] = None,
         timeout: int = 60,
         project_name: str = DEFAULT_PROJECT_NAME,
+        organization_id: Optional[str] = None,
         api_key: Optional[str] = None,
         base_url: Optional[str] = None,
         app_url: Optional[str] = None,
@@ -70,6 +71,7 @@
         """Initialize the Platform Index."""
         self.name = name
         self.project_name = project_name
+        self.organization_id = organization_id
         self.transformations = transformations or []

         if nodes is not None:
@@ -168,9 +170,31 @@ def _wait_for_documents_ingestion(
         # the pipeline status is success
         self._wait_for_pipeline_ingestion(verbose, raise_on_error)

+    def _get_project_id(self) -> str:
+        projects = self._client.projects.list_projects(
+            organization_id=self.organization_id,
+            project_name=self.project_name,
+        )
+        if len(projects) == 0:
+            raise ValueError(
+                f"Unknown project name {self.project_name}. Please confirm a "
+                "managed project with this name exists."
+            )
+        elif len(projects) > 1:
+            raise ValueError(
+                f"Multiple projects found with name {self.project_name}. Please specify organization_id."
+            )
+        project = projects[0]
+
+        if project.id is None:
+            raise ValueError(f"No project found with name {self.project_name}")
+
+        return project.id
+
     def _get_pipeline_id(self) -> str:
+        project_id = self._get_project_id()
         pipelines = self._client.pipelines.search_pipelines(
-            project_name=self.project_name,
+            project_id=project_id,
             pipeline_name=self.name,
             pipeline_type=PipelineType.MANAGED.value,
         )
@@ -199,6 +223,7 @@ def from_documents(  # type: ignore
         name: str,
         transformations: Optional[List[TransformComponent]] = None,
         project_name: str = DEFAULT_PROJECT_NAME,
+        organization_id: Optional[str] = None,
         api_key: Optional[str] = None,
         base_url: Optional[str] = None,
         app_url: Optional[str] = None,
@@ -221,7 +246,7 @@
         )

         project = client.projects.upsert_project(
-            request=ProjectCreate(name=project_name)
+            organization_id=organization_id, request=ProjectCreate(name=project_name)
         )
         if project.id is None:
             raise ValueError(f"Failed to create/get project {project_name}")
@@ -240,6 +265,7 @@ def from_documents(  # type: ignore
             name,
             transformations=transformations,
             project_name=project_name,
+            organization_id=project.organization_id,
             api_key=api_key,
             base_url=base_url,
             app_url=app_url,
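
The constructor change above takes the same keyword and stores it as self.organization_id, and _get_pipeline_id now resolves the project through the new _get_project_id helper, which filters list_projects by organization_id and raises when the project name matches zero projects or more than one. Below is a sketch of connecting to an existing index directly; only the parameter names come from the constructor signature above, and all values are placeholders.

from llama_index.indices.managed.llama_cloud import LlamaCloudIndex

# With organization_id supplied, a project name that is visible in more than one
# organization no longer hits the "Multiple projects found" error raised by
# _get_project_id.
index = LlamaCloudIndex(
    name="my managed pipeline",           # placeholder name of an existing pipeline
    project_name="Default",               # placeholder project name
    organization_id="<organization-id>",  # placeholder organization id
    api_key="<llama-cloud-api-key>",      # placeholder API key
)
retriever = index.as_retriever()
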
@@ -30,7 +30,7 @@ exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-indices-managed-llama-cloud"
 readme = "README.md"
-version = "0.2.5"
+version = "0.2.6"

 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
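
The version bump above means the change ships in release 0.2.6 of llama-index-indices-managed-llama-cloud, so upgrading the package (for example, pip install -U "llama-index-indices-managed-llama-cloud>=0.2.6") should be all that is needed to pick up the new parameter.
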
@@ -1,3 +1,4 @@
+from typing import Optional
 from llama_index.core.indices.managed.base import BaseManagedIndex
 from llama_index.indices.managed.llama_cloud import LlamaCloudIndex
 from llama_index.core.schema import Document
@@ -8,6 +9,7 @@
 base_url = os.environ.get("LLAMA_CLOUD_BASE_URL", None)
 api_key = os.environ.get("LLAMA_CLOUD_API_KEY", None)
 openai_api_key = os.environ.get("OPENAI_API_KEY", None)
+organization_id = os.environ.get("LLAMA_CLOUD_ORGANIZATION_ID", None)


 def test_class():
@@ -36,12 +38,13 @@ def test_retrieve():
     assert response is not None and len(response.response) > 0


+@pytest.mark.parametrize("organization_id", [None, organization_id])
 @pytest.mark.skipif(
     not base_url or not api_key, reason="No platform base url or api keyset"
 )
 @pytest.mark.skipif(not openai_api_key, reason="No openai api key set")
 @pytest.mark.integration()
-def test_documents_crud():
+def test_documents_crud(organization_id: Optional[str]):
     os.environ["OPENAI_API_KEY"] = openai_api_key
     documents = [
         Document(text="Hello world.", doc_id="1", metadata={"source": "test"}),
@@ -51,6 +54,8 @@ def test_documents_crud():
         name=f"test pipeline {uuid4()}",
         api_key=api_key,
         base_url=base_url,
+        organization_id=organization_id,
+        verbose=True,
     )
     docs = index.ref_doc_info
     assert len(docs) == 1
@@ -61,7 +66,8 @@
     assert all(n.node.metadata["source"] == "test" for n in nodes)

     index.insert(
-        Document(text="Hello world.", doc_id="2", metadata={"source": "inserted"})
+        Document(text="Hello world.", doc_id="2", metadata={"source": "inserted"}),
+        verbose=True,
     )
     docs = index.ref_doc_info
     assert len(docs) == 2
@@ -73,7 +79,8 @@ def test_documents_crud():
     assert any(n.node.ref_doc_id == "2" for n in nodes)

     index.update_ref_doc(
-        Document(text="Hello world.", doc_id="2", metadata={"source": "updated"})
+        Document(text="Hello world.", doc_id="2", metadata={"source": "updated"}),
+        verbose=True,
     )
     docs = index.ref_doc_info
     assert len(docs) == 2
@@ -90,7 +97,7 @@ def test_documents_crud():
     assert docs["3"].metadata["source"] == "refreshed"
     assert docs["1"].metadata["source"] == "refreshed"

-    index.delete_ref_doc("3")
+    index.delete_ref_doc("3", verbose=True)
     docs = index.ref_doc_info
     assert len(docs) == 2
     assert "3" not in docs
