Add Nutanix AI Endpoint #346

Open · wants to merge 6 commits into base: main
2 changes: 2 additions & 0 deletions README.md
@@ -88,6 +88,7 @@ Additionally, we have designed every element of the Stack such that APIs as well
| Chroma | Single Node | | | :heavy_check_mark: | | |
| PG Vector | Single Node | | | :heavy_check_mark: | | |
| PyTorch ExecuTorch | On-device iOS | :heavy_check_mark: | :heavy_check_mark: | | | |
| Nutanix AI | Hosted | | :heavy_check_mark: | | | |

### Distributions

@@ -99,6 +100,7 @@ Additionally, we have designed every element of the Stack such that APIs as well
| TGI | [llamastack/distribution-tgi](https://hub.docker.com/repository/docker/llamastack/distribution-tgi/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/tgi.html) |
| Together | [llamastack/distribution-together](https://hub.docker.com/repository/docker/llamastack/distribution-together/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/together.html) |
| Fireworks | [llamastack/distribution-fireworks](https://hub.docker.com/repository/docker/llamastack/distribution-fireworks/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/fireworks.html) |
| Nutanix | [distribution-nutanix](https://hub.docker.com/repository/docker/jinanz/distribution-nutanix/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/remote_hosted_distro/nutanix.html) |

## Installation

1 change: 1 addition & 0 deletions distributions/nutanix/build.yaml
@@ -0,0 +1 @@
../../llama_stack/templates/nutanix/build.yaml
15 changes: 15 additions & 0 deletions distributions/nutanix/compose.yaml
@@ -0,0 +1,15 @@
services:
  llamastack:
    image: distribution-nutanix
    volumes:
      - ~/.llama:/root/.llama
      - ./run.yaml:/root/llamastack-run-nutanix.yaml
    ports:
      - "5000:5000"
    entrypoint: bash -c "python -m llama_stack.distribution.server.server --yaml_config /root/llamastack-run-nutanix.yaml"
    deploy:
      restart_policy:
        condition: on-failure
        delay: 3s
        max_attempts: 5
        window: 60s
51 changes: 51 additions & 0 deletions distributions/nutanix/run.yaml
@@ -0,0 +1,51 @@
version: '2'
image_name: nutanix
docker_image: null
conda_env: nutanix
apis:
- agents
- inference
- memory
- safety
- telemetry
providers:
  inference:
  - provider_id: nutanix
    provider_type: remote::nutanix
    config:
      url: https://ai.nutanix.com/api/v1
      api_key: ${env.NUTANIX_API_KEY}
  memory:
  - provider_id: faiss
    provider_type: inline::faiss
    config:
      kvstore:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nutanix}/faiss_store.db
  safety:
  - provider_id: nutanix
    provider_type: remote::nutanix
    config: {}
  agents:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      persistence_store:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nutanix}/agents_store.db
  telemetry:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config: {}
metadata_store:
  namespace: null
  type: sqlite
  db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nutanix}/registry.db
models: []
shields: []
memory_banks: []
datasets: []
scoring_fns: []
eval_tasks: []
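
The `${env.VAR:default}` placeholders above are resolved from the environment when the config is loaded. A rough standalone illustration of that substitution (a simplified sketch, not the actual Llama Stack resolution code):

```python
import os
import re

def resolve_env(value: str) -> str:
    """Replace ${env.NAME:default} placeholders with environment values (sketch)."""
    def repl(match: re.Match) -> str:
        name, _, default = match.group(1).partition(":")
        return os.environ.get(name, default)
    return re.sub(r"\$\{env\.([^}]+)\}", repl, value)

# With SQLITE_STORE_DIR unset, the default after the colon is used:
print(resolve_env("${env.SQLITE_STORE_DIR:~/.llama/distributions/nutanix}/registry.db"))
# -> ~/.llama/distributions/nutanix/registry.db
```
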
1 change: 1 addition & 0 deletions docs/source/distributions/index.md
@@ -32,6 +32,7 @@ If so, we suggest:
- **Do you have an API key for a remote inference provider like Fireworks, Together, etc.?** If so, we suggest:
- {dockerhub}`distribution-together` ([Guide](remote_hosted_distro/index))
- {dockerhub}`distribution-fireworks` ([Guide](remote_hosted_distro/index))
- {dockerhub}`distribution-nutanix` ([Guide](remote_hosted_distro/index))

- **Do you want to run Llama Stack inference on your iOS / Android device?** If so, we suggest:
- [iOS SDK](ondevice_distro/ios_sdk)
1 change: 1 addition & 0 deletions docs/source/distributions/remote_hosted_distro/index.md
@@ -9,6 +9,7 @@ Remote-Hosted distributions are available endpoints serving Llama Stack API that
|-------------|----------|-----------|---------|---------|---------|------------|
| Together | [https://llama-stack.together.ai](https://llama-stack.together.ai) | remote::together | meta-reference | remote::weaviate | meta-reference | meta-reference |
| Fireworks | [https://llamastack-preview.fireworks.ai](https://llamastack-preview.fireworks.ai) | remote::fireworks | meta-reference | remote::weaviate | meta-reference | meta-reference |
| Nutanix | [https://llamastack-preview.nutanix.ai](https://llamastack-preview.nutanix.ai) | remote::nutanix | meta-reference | meta-reference | meta-reference | meta-reference |

## Connecting to Remote-Hosted Distributions

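
For orientation, a minimal client-side sketch of talking to one of these remote-hosted endpoints. This assumes the `llama-stack-client` Python package; the exact parameter and response field names here are assumptions, not part of this PR:

```python
from llama_stack_client import LlamaStackClient

# Point the client at the hosted Nutanix endpoint listed in the table above.
client = LlamaStackClient(base_url="https://llamastack-preview.nutanix.ai")

response = client.inference.chat_completion(
    model_id="Llama3.1-8B-Instruct",  # assumed model identifier
    messages=[{"role": "user", "content": "Hello"}],
)
print(response.completion_message.content)
```
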
11 changes: 11 additions & 0 deletions llama_stack/providers/registry/inference.py
@@ -161,4 +161,15 @@ def available_providers() -> List[ProviderSpec]:
                config_class="llama_stack.providers.remote.inference.nvidia.NVIDIAConfig",
            ),
        ),
        remote_provider_spec(
            api=Api.inference,
            adapter=AdapterSpec(
                adapter_type="nutanix",
                pip_packages=[
                    "openai",
                ],
                module="llama_stack.providers.remote.inference.nutanix",
                config_class="llama_stack.providers.remote.inference.nutanix.NutanixImplConfig",
            ),
        ),
    ]
18 changes: 18 additions & 0 deletions llama_stack/providers/remote/inference/nutanix/__init__.py
@@ -0,0 +1,18 @@
# Copyright (c) Nutanix, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .config import NutanixImplConfig


async def get_adapter_impl(config: NutanixImplConfig, _deps):
    from .nutanix import NutanixInferenceAdapter

    assert isinstance(
        config, NutanixImplConfig
    ), f"Unexpected config type: {type(config)}"
    impl = NutanixInferenceAdapter(config)
    await impl.initialize()
    return impl
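
Manually wiring the adapter outside the server would look roughly like the sketch below. In practice the Llama Stack server constructs it from `run.yaml`; the API key shown is a placeholder:

```python
import asyncio

from llama_stack.providers.remote.inference.nutanix import NutanixImplConfig, get_adapter_impl

async def main() -> None:
    config = NutanixImplConfig(api_key="nai-...")  # placeholder key
    adapter = await get_adapter_impl(config, _deps={})
    print(type(adapter).__name__)  # NutanixInferenceAdapter

asyncio.run(main())
```
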
29 changes: 29 additions & 0 deletions llama_stack/providers/remote/inference/nutanix/config.py
@@ -0,0 +1,29 @@
# Copyright (c) Nutanix, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, Dict, Optional

from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field


@json_schema_type
class NutanixImplConfig(BaseModel):
    url: str = Field(
        default="https://ai.nutanix.com/api/v1",
        description="The URL of the Nutanix AI Endpoint",
    )
    api_key: Optional[str] = Field(
        default=None,
        description="The API key to the Nutanix AI Endpoint",
    )

    @classmethod
    def sample_run_config(cls) -> Dict[str, Any]:
        return {
            "url": "https://ai.nutanix.com/api/v1",
            "api_key": "${env.NUTANIX_API_KEY}",
        }
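
A small sketch of how this config behaves, based only on the class above (the key value is a placeholder):

```python
from llama_stack.providers.remote.inference.nutanix.config import NutanixImplConfig

cfg = NutanixImplConfig(api_key="nai-...")  # url falls back to the default endpoint
print(cfg.url)                              # https://ai.nutanix.com/api/v1
print(NutanixImplConfig.sample_run_config())
# {'url': 'https://ai.nutanix.com/api/v1', 'api_key': '${env.NUTANIX_API_KEY}'}
```
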
147 changes: 147 additions & 0 deletions llama_stack/providers/remote/inference/nutanix/nutanix.py
@@ -0,0 +1,147 @@
# Copyright (c) Nutanix, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import AsyncGenerator

from openai import OpenAI

from llama_models.llama3.api.chat_format import ChatFormat

from llama_models.llama3.api.datatypes import Message
from llama_models.llama3.api.tokenizer import Tokenizer

from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.utils.inference.model_registry import (
    build_model_alias,
    ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
    get_sampling_options,
    process_chat_completion_response,
    process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_messages,
)

from .config import NutanixImplConfig


MODEL_ALIASES = [
    build_model_alias(
        "vllm-llama-3-1",
        CoreModelId.llama3_1_8b_instruct.value,
    ),
]


class NutanixInferenceAdapter(ModelRegistryHelper, Inference):
Review thread on this class:

Contributor: @ashwinb this is almost the same code as fireworks and databricks. What do you think of having a common base class?

Contributor: @mattf Yes, I think we need to start consolidating more on the code side.

We have some tests now, but we also need to put down some more requirements for when a new inference provider comes in. Here are some things we are thinking about:

  • support for structured decoding -- kind of table stakes now
  • proper support for tool calling (either directly, or via a legacy completions API so Llama Stack can format the prompt)
  • support for vision models

Otherwise we cannot claim to the user that "you can just use Llama Stack and pick-and-choose any provider and you will get a consistent experience".

Author: Thank you for the feedback. While the code does share similarities with Fireworks and Databricks, there are important differences, and we anticipate adding new features that will further differentiate our implementation from those of other vendors.

I believe it may be more efficient for each vendor to maintain their own Llama Stack adapter. The duplication of code within each adapter, in this context, is manageable and can even be beneficial. Adopting a "Do Repeat Yourself" approach for these adapters aligns with maintaining clarity and flexibility, especially given the unique requirements and evolution of individual providers.

That said, I'm open to further discussion if there's a strong case for a shared base class or an alternative approach. Let me know your thoughts!

Contributor: @jinan-zhou thank you for the thoughtful argument. I think you're right that abstracting the providers now is too early. I raise the topic only to start a discussion, not to block or slow your valuable contribution.

    def __init__(self, config: NutanixImplConfig) -> None:
        ModelRegistryHelper.__init__(self, MODEL_ALIASES)
        self.config = config
        self.formatter = ChatFormat(Tokenizer.get_instance())

    async def initialize(self) -> None:
        return

    async def shutdown(self) -> None:
        pass

    def _get_client(self) -> OpenAI:
        nutanix_api_key = None
        if self.config.api_key:
            nutanix_api_key = self.config.api_key
        else:
            provider_data = self.get_request_provider_data()
            if provider_data is None or not provider_data.nutanix_api_key:
                raise ValueError(
                    'Pass Nutanix API Key in the header X-LlamaStack-ProviderData as { "nutanix_api_key": <your api key> }'
                )
            nutanix_api_key = provider_data.nutanix_api_key

        return OpenAI(base_url=self.config.url, api_key=nutanix_api_key)
    async def completion(
        self,
        model_id: str,
        content: InterleavedTextMedia,
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        raise NotImplementedError()

    async def chat_completion(
        self,
        model_id: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        response_format: Optional[ResponseFormat] = None,
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        model = await self.model_store.get_model(model_id)
        request = ChatCompletionRequest(
            model=model.provider_resource_id,
            messages=messages,
            sampling_params=sampling_params,
            tools=tools or [],
            tool_choice=tool_choice,
            tool_prompt_format=tool_prompt_format,
            stream=stream,
            logprobs=logprobs,
        )

        client = self._get_client()
        if stream:
            return self._stream_chat_completion(request, client)
        else:
            return await self._nonstream_chat_completion(request, client)

    async def _nonstream_chat_completion(
        self, request: ChatCompletionRequest, client: OpenAI
    ) -> ChatCompletionResponse:
        params = self._get_params(request)
        r = client.chat.completions.create(**params)
        return process_chat_completion_response(r, self.formatter)

    async def _stream_chat_completion(
        self, request: ChatCompletionRequest, client: OpenAI
    ) -> AsyncGenerator:
        params = self._get_params(request)

        async def _to_async_generator():
            s = client.chat.completions.create(**params)
            for chunk in s:
                yield chunk

        stream = _to_async_generator()
        async for chunk in process_chat_completion_stream_response(
            stream, self.formatter
        ):
            yield chunk

    def _get_params(self, request: ChatCompletionRequest) -> dict:
        params = {
            "model": request.model,
            "messages": chat_completion_request_to_messages(
                request, self.get_llama_model(request.model), return_dict=True
            ),
            "stream": request.stream,
            **get_sampling_options(request.sampling_params),
        }
        return params

    async def embeddings(
        self,
        model_id: str,
        contents: List[InterleavedTextMedia],
    ) -> EmbeddingsResponse:
        raise NotImplementedError()
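
Because the adapter is OpenAI-compatible, the request it ultimately issues is equivalent to calling the endpoint with the `openai` client directly. A standalone sketch of what `_get_client()` plus `_get_params()` amount to (assumes a valid `NUTANIX_API_KEY` in the environment):

```python
import os

from openai import OpenAI

client = OpenAI(
    base_url="https://ai.nutanix.com/api/v1",
    api_key=os.environ["NUTANIX_API_KEY"],
)
response = client.chat.completions.create(
    model="vllm-llama-3-1",  # provider-side id from MODEL_ALIASES above
    messages=[{"role": "user", "content": "Hello"}],
)
print(response.choices[0].message.content)
```
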
2 changes: 2 additions & 0 deletions llama_stack/providers/utils/inference/openai_compat.py
@@ -43,6 +43,8 @@ def get_sampling_options(params: SamplingParams) -> dict:


def text_from_choice(choice) -> str:
    if hasattr(choice, "message") and choice.message:
        return choice.message.content
    if hasattr(choice, "delta") and choice.delta:
        return choice.delta.content

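
A quick standalone check of the new message branch, using `SimpleNamespace` stand-ins rather than real OpenAI response objects:

```python
from types import SimpleNamespace

def text_from_choice(choice) -> str:
    # Same branches as above: prefer a non-streaming message, fall back to a streaming delta.
    if hasattr(choice, "message") and choice.message:
        return choice.message.content
    if hasattr(choice, "delta") and choice.delta:
        return choice.delta.content

non_streaming = SimpleNamespace(message=SimpleNamespace(content="Hello"), delta=None)
streaming = SimpleNamespace(message=None, delta=SimpleNamespace(content="Hel"))

print(text_from_choice(non_streaming))  # Hello
print(text_from_choice(streaming))      # Hel
```
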
9 changes: 7 additions & 2 deletions llama_stack/providers/utils/inference/prompt_adapter.py
@@ -169,10 +169,12 @@ def chat_completion_request_to_model_input_info(
def chat_completion_request_to_messages(
    request: ChatCompletionRequest,
    llama_model: str,
) -> List[Message]:
    return_dict: bool = False,
) -> Union[List[Message], List[Dict[str, str]]]:
    """Reads chat completion request and augments the messages to handle tools.
    For eg. for llama_3_1, add system message with the appropriate tools or
    add user message for custom tools, etc.
    If return_dict is set, returns a list of message dictionaries instead of Message objects.
    """
    model = resolve_model(llama_model)
    if model is None:
@@ -199,7 +201,10 @@ def chat_completion_request_to_messages(
    if fmt_prompt := response_format_prompt(request.response_format):
        messages.append(UserMessage(content=fmt_prompt))

    return messages
    if return_dict:
        return [{'role': m.role, 'content': m.content} for m in messages]
    else:
        return messages


def response_format_prompt(fmt: Optional[ResponseFormat]):
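
To illustrate what the `return_dict` path produces, a small sketch using the message types this module already works with (the import path is assumed from the adapter code in this PR):

```python
from llama_models.llama3.api.datatypes import SystemMessage, UserMessage

messages = [
    SystemMessage(content="You are a helpful assistant."),
    UserMessage(content="What models does the Nutanix AI Endpoint serve?"),
]

# The same conversion chat_completion_request_to_messages applies when return_dict=True:
as_dicts = [{"role": m.role, "content": m.content} for m in messages]
print(as_dicts)  # roles come through as 'system' and 'user', content unchanged
```
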
7 changes: 7 additions & 0 deletions llama_stack/templates/nutanix/__init__.py
@@ -0,0 +1,7 @@
# Copyright (c) Nutanix, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .nutanix import get_distribution_template # noqa: F401
12 changes: 12 additions & 0 deletions llama_stack/templates/nutanix/build.yaml
@@ -0,0 +1,12 @@
name: nutanix
distribution_spec:
  description: Use Nutanix AI Endpoint for running LLM inference
  providers:
    inference: remote::nutanix
    memory:
    - inline::faiss
    - remote::chromadb
    - remote::pgvector
    safety: inline::llama-guard
    agents: inline::meta-reference
    telemetry: inline::meta-reference