Skip to content

Commit

Permalink
Merge branch 'main' into hy/type_hints_all
Browse files Browse the repository at this point in the history
  • Loading branch information
yihong0618 committed Dec 20, 2024
2 parents a9df950 + 463fbe2 commit 2c89106
Show file tree
Hide file tree
Showing 8 changed files with 63 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,10 @@ def _invoke(
average = embeddings_batch[0]
else:
average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
embeddings[i] = (average / np.linalg.norm(average)).tolist()
embedding = (average / np.linalg.norm(average)).tolist()
if np.isnan(embedding).any():
raise ValueError("Normalized embedding is nan please try again")
embeddings[i] = embedding

# calc usage
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,10 @@ def _invoke(
average = embeddings_batch[0]
else:
average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
embeddings[i] = (average / np.linalg.norm(average)).tolist()
embedding = (average / np.linalg.norm(average)).tolist()
if np.isnan(embedding).any():
raise ValueError("Normalized embedding is nan please try again")
embeddings[i] = embedding

# calc usage
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
- gemini-2.0-flash-exp
- gemini-2.0-flash-thinking-exp-1219
- gemini-1.5-pro
- gemini-1.5-pro-latest
- gemini-1.5-pro-001
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Provider model definition for Google's experimental "Gemini 2.0 Flash Thinking"
# release (exp-1219). Registers the model as a chat LLM with multimodal input
# features and zero-cost pricing (experimental models are free at time of writing).
model: gemini-2.0-flash-thinking-exp-1219
label:
  en_US: Gemini 2.0 Flash Thinking Exp 1219
model_type: llm
features:
  - agent-thought
  - vision
  - document
  - video
  - audio
model_properties:
  mode: chat
  # NOTE(review): 32767 looks like it should be 32768 (32k) — confirm against
  # Google's published context window for this model.
  context_size: 32767
parameter_rules:
  - name: temperature
    use_template: temperature
  - name: top_p
    use_template: top_p
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
  - name: max_output_tokens
    use_template: max_tokens
    default: 8192
    min: 1
    max: 8192
  - name: json_schema
    use_template: json_schema
# Experimental release: no charge for input or output tokens.
pricing:
  input: '0.00'
  output: '0.00'
  unit: '0.000001'
  currency: USD
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,10 @@ def _invoke(
average = embeddings_batch[0]
else:
average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
embeddings[i] = (average / np.linalg.norm(average)).tolist()
embedding = (average / np.linalg.norm(average)).tolist()
if np.isnan(embedding).any():
raise ValueError("Normalized embedding is nan please try again")
embeddings[i] = embedding

# calc usage
usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,10 @@ def _invoke(
average = embeddings_batch[0]
else:
average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
embeddings[i] = (average / np.linalg.norm(average)).tolist()
embedding = (average / np.linalg.norm(average)).tolist()
if np.isnan(embedding).any():
raise ValueError("Normalized embedding is nan please try again")
embeddings[i] = embedding

usage = self._calc_response_usage(model=model, credentials=credentials, tokens=used_tokens)

Expand Down
4 changes: 3 additions & 1 deletion api/core/rag/embedding/cached_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,9 @@ def embed_query(self, text: str) -> list[float]:
)

embedding_results = embedding_result.embeddings[0]
embedding_results = (np.array(embedding_results) / np.linalg.norm(np.array(embedding_results))).tolist()
embedding_results = (embedding_results / np.linalg.norm(embedding_results)).tolist()
if np.isnan(embedding_results).any():
raise ValueError("Normalized embedding is nan please try again")
except Exception as ex:
if dify_config.DEBUG:
logging.exception(f"Failed to embed query text '{text[:10]}...({len(text)} chars)'")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,10 @@ def queue_prompt(self, client_id: str, prompt: dict) -> str:
def open_websocket_connection(self) -> tuple[WebSocket, str]:
    """Open a WebSocket connection to the ComfyUI server.

    Generates a fresh client id, derives the WebSocket scheme from the
    HTTP scheme of ``self.base_url`` (``https`` -> ``wss``, anything
    else -> ``ws``), and connects to the server's ``/ws`` endpoint.

    Returns:
        A tuple of the connected ``WebSocket`` and the ``client_id``
        string used to identify this session on the server.
    """
    client_id = str(uuid.uuid4())
    ws = WebSocket()
    # Mirror the HTTP scheme: a TLS base URL must use a TLS WebSocket,
    # otherwise the connection is rejected or downgraded.
    ws_protocol = "wss" if self.base_url.scheme == "https" else "ws"
    ws_address = f"{ws_protocol}://{self.base_url.authority}/ws?clientId={client_id}"
    ws.connect(ws_address)
    return ws, client_id

Expand Down

0 comments on commit 2c89106

Please sign in to comment.