perf: Optimize word segmentation retrieval #2767

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged · 1 commit · Apr 1, 2025
36 changes: 4 additions & 32 deletions apps/common/util/ts_vecto_util.py
@@ -12,9 +12,6 @@

 import jieba
 import jieba.posseg
-from jieba import analyse
-
-from common.util.split_model import group_by

 jieba_word_list_cache = [chr(item) for item in range(38, 84)]

@@ -80,37 +77,12 @@ def get_key_by_word_dict(key, word_dict):


 def to_ts_vector(text: str):
-    # Collect the strings that must not be segmented
-    word_list = get_word_list(text)
-    # Build the keyword mapping
-    word_dict = to_word_dict(word_list, text)
-    # Substitute the keywords in the text
-    text = replace_word(word_dict, text)
-    # Segment the text
-    filter_word = jieba.analyse.extract_tags(text, topK=100)
-    result = jieba.lcut(text, HMM=True, use_paddle=True)
-    # Filter out punctuation
-    result = [item for item in result if filter_word.__contains__(item) and len(item) < 10]
-    result_ = [{'word': get_key_by_word_dict(result[index], word_dict), 'index': index} for index in
-               range(len(result))]
-    result_group = group_by(result_, lambda r: r['word'])
-    return " ".join(
-        [f"{key.lower()}:{','.join([str(item['index'] + 1) for item in result_group[key]][:20])}" for key in
-         result_group if
-         not remove_chars.__contains__(key) and len(key.strip()) >= 0])
+    result = jieba.lcut(text)
+    return " ".join(result)


 def to_query(text: str):
-    # Collect the strings that must not be segmented
-    word_list = get_word_list(text)
-    # Build the keyword mapping
-    word_dict = to_word_dict(word_list, text)
-    # Substitute the keywords in the text
-    text = replace_word(word_dict, text)
-    extract_tags = analyse.extract_tags(text, topK=5, withWeight=True, allowPOS=('ns', 'n', 'vn', 'v', 'eng'))
-    result = " ".join([get_key_by_word_dict(word, word_dict) for word, score in extract_tags if
-                       not remove_chars.__contains__(word)])
-    # Remove the custom words from jieba's dictionary
-    for word in word_list:
-        jieba.del_word(word)
+    extract_tags = jieba.lcut(text)
+    result = " ".join(extract_tags)
     return result
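
For reference, a minimal sketch of what the simplified functions now produce. The sample sentence is jieba's own documented example; the exact segmentation depends on jieba's dictionary and version:

import jieba

# Default-mode segmentation of jieba's documented example sentence
print(jieba.lcut("我来到北京清华大学"))
# ['我', '来到', '北京', '清华大学']

# to_ts_vector() and to_query() now simply space-join the segments
print(" ".join(jieba.lcut("我来到北京清华大学")))
# 我 来到 北京 清华大学

Note that this output is no longer in PostgreSQL's lexeme:positions tsvector literal format that the old implementation built by hand, which is presumably why the pg_vector.py change below wraps the result in SearchVector.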
Contributor Author commented:
Code Review and Suggestions

  1. Comments: The comments in the code are mostly clear but would benefit from more specific detail on what each function does.
  2. Imports: Import analyse directly instead of referencing jieba.analyse. This makes it easier to see where functions like extract_tags come from.
  3. Functionality: In to_ts_vector, replace the keyword-extraction and filtering steps with a direct call to jieba.lcut. The same applies to to_query: call jieba.lcut directly, without the unnecessary preprocessing.
  4. Optimization: Drop the replace_word call from both to_ts_vector and to_query. If the function is still needed for specific purposes, consider refactoring its implementation; otherwise remove it entirely.
  5. Character handling: Ensure that remove_chars remains properly defined and referenced if it is used elsewhere in the code.
  6. Consistency: Ensure consistency in how different components handle text preprocessing (e.g., punctuation stripping).

Here's an updated version of the cleaned-up code:

import jieba

# Constants and caching
jieba_word_list_cache = [chr(item) for item in range(38, 84)]

def get_key_by_word_dict(key, word_dict):
    # Kept as a stub so existing callers still resolve; no longer used here
    raise NotImplementedError("Not implemented yet")

def to_ts_vector(text: str):
    # Segment the text and space-join the tokens
    words = jieba.lcut(text)
    return " ".join(words)

def to_query(text: str):
    # Same segmentation on the query side
    words = jieba.lcut(text)
    return " ".join(words)

This version removes the unneeded logic and leans directly on jieba's built-in segmentation.
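
One caveat, shown as a small illustrative sketch (the exact tokens depend on jieba's dictionary and version): jieba.lcut keeps punctuation as standalone tokens, so the punctuation stripping the old implementation performed no longer happens anywhere:

import jieba

# Punctuation comes back as its own tokens rather than being dropped
print(jieba.lcut("你好，世界！"))
# e.g. ['你好', '，', '世界', '！']

If point 6 above matters in practice, a filter over the token list would need to be reintroduced.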

7 changes: 5 additions & 2 deletions apps/embedding/vector/pg_vector.py
@@ -12,7 +12,9 @@
 from abc import ABC, abstractmethod
 from typing import Dict, List

-from django.db.models import QuerySet
+import jieba
+from django.contrib.postgres.search import SearchVector
+from django.db.models import QuerySet, Value
 from langchain_core.embeddings import Embeddings

 from common.db.search import generate_sql_by_query_dict

@@ -68,7 +70,8 @@ def _batch_save(self, text_list: List[Dict], embedding: Embeddings, is_the_task_interrupted):
                           source_id=text_list[index].get('source_id'),
                           source_type=text_list[index].get('source_type'),
                           embedding=embeddings[index],
-                          search_vector=to_ts_vector(text_list[index]['text'])) for index in
+                          search_vector=SearchVector(Value(to_ts_vector(text_list[index]['text'])))) for
+                          index in
                           range(0, len(texts))]
         if not is_the_task_interrupted():
             QuerySet(Embedding).bulk_create(embedding_list) if len(embedding_list) > 0 else None
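
A brief note on the Value() wrapper (a sketch of the general Django pattern, not part of this diff): SearchVector interprets bare strings as model field names, so the pre-segmented text has to be passed as a literal expression for PostgreSQL to run to_tsvector() over it:

from django.contrib.postgres.search import SearchVector
from django.db.models import Value

# SearchVector('text') would read the model's "text" column;
# SearchVector(Value(...)) feeds the literal, pre-segmented string
# through PostgreSQL's to_tsvector() instead.
vector = SearchVector(Value("我 来到 北京 清华大学"))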
Contributor Author commented:
There are no obvious irregularities in the provided code snippet. However, there are a few suggestions for improvement:

  1. Jieba import: The import of jieba at line 16 seems unnecessary, since it is not used anywhere within _batch_save. Remove it to simplify the code.
  2. Correct use of SearchVector: In the query-generation logic starting at line 68, the SearchVector is built from a nested function call, which may not behave correctly depending on your database setup. Consider whether Django's built-in full-text search support (e.g., PostgreSQL's full-text search capabilities) is what you need here.
  3. Missing-key checks: Ensure each dictionary entry has the keys 'source_id', 'source_type', and 'text'. Guarding with a condition like index < len(text_list) instead of assuming len(text_list) == len(texts) is more robust, especially if other parameters change in future versions.
  4. Error handling: The function has no error handling, which makes debugging harder. Wrapping the database operations in try-except blocks helps catch errors early.

Here is a revised sketch based on these points (the Embedding model fields and the embed_documents call are assumed from the surrounding diff context):

from typing import Dict, List

from django.contrib.postgres.search import SearchVector
from django.db.models import QuerySet, Value
from langchain_core.embeddings import Embeddings

from common.util.ts_vecto_util import to_ts_vector


def _batch_save(self, text_list: List[Dict], embedding: Embeddings, is_the_task_interrupted):
    # Point 3: validate entries up front instead of assuming the keys exist
    rows = [row for row in text_list
            if row.get('source_id') and row.get('source_type') and row.get('text')]
    texts = [row['text'] for row in rows]
    embeddings = embedding.embed_documents(texts)
    # Point 3: guard the index instead of assuming len(rows) == len(embeddings)
    embedding_list = [
        Embedding(source_id=row['source_id'],
                  source_type=row['source_type'],
                  embedding=embeddings[index],
                  search_vector=SearchVector(Value(to_ts_vector(row['text']))))
        for index, row in enumerate(rows) if index < len(embeddings)]
    if not is_the_task_interrupted() and embedding_list:
        try:
            QuerySet(Embedding).bulk_create(embedding_list)
        except Exception as e:
            # Point 4: surface database failures instead of failing silently
            raise RuntimeError(f"bulk_create failed for {len(embedding_list)} rows") from e

# Example: filtering out entries with empty or missing values before saving
texts_with_invalid_entries = [{'id': 1, 'source_id': '', 'text': 'A good example'}]
embeddings_for_entries = [[0.1, 0.2], [0.3, 0.4]]
pairs = zip(texts_with_invalid_entries, embeddings_for_entries)

filtered_pairs = [(t, e) for t, e in pairs
                  if all(v not in ('', None) for v in t.values())]

print(filtered_pairs)  # [] (the empty source_id disqualifies the entry)

This should help ensure reliability and maintainability of your code.
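
On the retrieval side (not shown in this PR excerpt), the stored vector would typically be matched with a SearchQuery built from the same segmentation. A hypothetical sketch, assuming Embedding.search_vector is a SearchVectorField:

from django.contrib.postgres.search import SearchQuery
from django.db.models import QuerySet

# Segment the user's query with the same jieba pipeline used at index
# time, then match it against the stored search_vector column.
tokens = to_query("北京 清华大学")
hits = QuerySet(Embedding).filter(search_vector=SearchQuery(tokens))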
