Description
What happened + What you expected to happen
I am trying to embed roughly 5 GB of PDFs and other files and then insert them into a vector database (Milvus). The pipeline slowly builds up to 1 TB of object_store_memory usage, at which point it fails. No matter what I change (parallelism, batch size, etc.), I can't stop this from happening.
The somewhat unusual extract function is there because I need textract to read each file from its path rather than from its raw bytes; in my testing, textract gives the best text extraction.
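For reference, this is the kind of tuning I have been trying without success (the exact values below are illustrative, not from any single run), plus how the object store growth can be observed while the job runs:
import ray

ray.init()

# Lowering or raising the read parallelism here, and changing the batch_size and
# actor pool size on the map_batches embedding step in the script below, makes no
# difference: object_store_memory keeps climbing until the job fails.
ds = ray.data.read_binary_files(
    ["local://filepath"], include_paths=True, parallelism=64
)

# While the pipeline runs, the growth is visible from another shell with:
#   ray memory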
Versions / Dependencies
accelerate==0.25.0
aiofiles==23.2.1
aiohttp==3.9.1
aiohttp-cors==0.7.0
aiorwlock==1.3.0
aiosignal==1.3.1
annotated-types==0.6.0
anyio==3.7.1
argcomplete==1.10.3
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
asgiref==3.7.2
asyncer==0.0.2
attrs==23.1.0
backoff==2.2.1
bcrypt==4.1.1
beautifulsoup4==4.8.2
bidict==0.22.1
blessed==1.20.0
blis==0.7.11
cachetools==5.3.2
catalogue==2.0.10
certifi==2023.11.17
cffi==1.16.0
chainlit==0.7.700
chardet==3.0.4
charset-normalizer==3.3.2
chroma-hnswlib==0.7.3
chromadb==0.4.19
click==8.1.7
cloudpathlib==0.16.0
coloredlogs==15.0.1
colorful==0.5.5
compressed-rtf==1.0.6
confection==0.1.4
cymem==2.0.8
dataclasses-json==0.5.14
Deprecated==1.2.14
distlib==0.3.8
distro==1.8.0
docx2txt==0.8
ebcdic==1.1.1
environs==9.5.0
extract-msg==0.28.7
fastapi==0.100.1
fastapi-socketio==0.0.10
filelock==3.13.1
filetype==1.2.0
flatbuffers==23.5.26
frozenlist==1.4.0
fsspec==2023.12.2
google-api-core==2.15.0
google-auth==2.25.2
googleapis-common-protos==1.62.0
gpustat==1.1.1
greenlet==3.0.2
grpcio==1.58.0
gunicorn==21.2.0
h11==0.14.0
httpcore==0.17.3
httptools==0.6.1
httpx==0.24.1
huggingface-hub==0.19.4
humanfriendly==10.0
idna==3.6
IMAPClient==2.1.0
importlib-metadata==6.11.0
importlib-resources==6.1.1
install==1.3.5
Jinja2==3.1.2
joblib==1.3.2
jsonpatch==1.33
jsonpointer==2.4
jsonschema==4.20.0
jsonschema-specifications==2023.12.1
kubernetes==28.1.0
langchain==0.0.350
langchain-community==0.0.3
langchain-core==0.1.0
langcodes==3.3.0
langsmith==0.0.69
Lazify==0.4.0
lxml==4.9.3
MarkupSafe==2.1.3
marshmallow==3.20.1
mergedeep==1.3.4
minio==7.2.3
mmh3==4.0.1
monotonic==1.6
mpmath==1.3.0
msgpack==1.0.7
multidict==6.0.4
murmurhash==1.0.10
mypy-extensions==1.0.0
nest-asyncio==1.5.8
networkx==3.2.1
nltk==3.8.1
numpy==1.26.2
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-ml-py==12.535.133
nvidia-nccl-cu12==2.18.1
nvidia-nvjitlink-cu12==12.3.101
nvidia-nvtx-cu12==12.1.105
oauthlib==3.2.2
olefile==0.47
onnxruntime==1.16.3
openai==1.3.9
opencensus==0.11.4
opencensus-context==0.1.3
opentelemetry-api==1.21.0
opentelemetry-exporter-otlp==1.21.0
opentelemetry-exporter-otlp-proto-common==1.21.0
opentelemetry-exporter-otlp-proto-grpc==1.21.0
opentelemetry-exporter-otlp-proto-http==1.21.0
opentelemetry-instrumentation==0.42b0
opentelemetry-instrumentation-asgi==0.42b0
opentelemetry-instrumentation-fastapi==0.42b0
opentelemetry-proto==1.21.0
opentelemetry-sdk==1.21.0
opentelemetry-semantic-conventions==0.42b0
opentelemetry-util-http==0.42b0
overrides==7.4.0
packaging==23.2
pandas==2.1.4
pdfminer.six==20191110
Pillow==10.1.0
platformdirs==3.11.0
posthog==3.1.0
preshed==3.0.9
prometheus-client==0.19.0
protobuf==4.25.1
psutil==5.9.6
pulsar-client==3.3.0
py-spy==0.3.14
pyarrow==14.0.2
pyasn1==0.5.1
pyasn1-modules==0.3.0
pycparser==2.21
pycryptodome==3.19.0
pydantic==2.5.2
pydantic_core==2.14.5
PyJWT==2.8.0
pymilvus==2.3.5
PyPika==0.48.9
python-dateutil==2.8.2
python-dotenv==1.0.0
python-engineio==4.8.0
python-graphql-client==0.4.3
python-multipart==0.0.6
python-pptx==0.6.23
python-socketio==5.10.0
pytz==2023.3.post1
PyYAML==6.0.1
ray==2.9.0
referencing==0.32.0
regex==2023.10.3
requests==2.31.0
requests-oauthlib==1.3.1
rpds-py==0.16.2
rsa==4.9
safetensors==0.4.1
scikit-learn==1.3.2
scipy==1.11.4
sentence-transformers==2.2.2
sentencepiece==0.1.99
simple-websocket==1.0.0
six==1.16.0
smart-open==6.4.0
sniffio==1.3.0
sortedcontainers==2.4.0
soupsieve==2.5
spacy==3.7.2
spacy-legacy==3.0.12
spacy-loggers==1.0.5
SpeechRecognition==3.8.1
SQLAlchemy==2.0.23
srsly==2.4.8
starlette==0.27.0
sympy==1.12
syncer==2.0.3
tarsafe==0.0.5
tenacity==8.2.3
textract==1.6.5
thinc==8.2.1
threadpoolctl==3.2.0
tokenizers==0.15.0
tomli==2.0.1
torch==2.1.1
torchvision==0.16.1
tqdm==4.66.1
transformers==4.36.0
triton==2.1.0
typer==0.9.0
typing-inspect==0.9.0
typing_extensions==4.9.0
tzdata==2023.3
tzlocal==5.2
ujson==5.9.0
uptrace==1.21.0
urllib3==1.26.18
uvicorn==0.23.2
uvloop==0.19.0
virtualenv==20.21.0
wasabi==1.1.2
watchfiles==0.20.0
wcwidth==0.2.12
weasel==0.3.4
websocket-client==1.7.0
websockets==12.0
wrapt==1.16.0
wsproto==1.2.0
xlrd==1.2.0
XlsxWriter==3.1.9
yarl==1.9.4
zipp==3.17.0
Reproduction script
import binascii
import io
from typing import List
import textract
import ray
ray.init(
    runtime_env={"pip": ["langchain", "sentence_transformers", "transformers"]}
)
ds = ray.data.read_binary_files(['local://filepath'], include_paths=True, parallelism=256)
def extract(file):
    # textract reads the document from its path (not from raw bytes); in my
    # testing it gives the best text extraction quality.
    try:
        file['item'] = textract.process(str(file['path']), encoding="utf8").decode("utf-8")
    except Exception:
        file['item'] = ""
    return file

# `extract` is a 1->1 transform (one text blob per input file), so we use `map`.
ds = ds.map(extract)
from langchain.text_splitter import RecursiveCharacterTextSplitter
def split_text(page_dict):
    # Use a chunk_size of 256 characters with an overlap of 100.
    # This parameter can be modified based on your documents and use case.
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=256, chunk_overlap=100, length_function=len
    )
    chunks: List[str] = text_splitter.split_text(page_dict['item'])
    return [dict(page_dict, **{"item": text.replace("\n", " ")}) for text in chunks]
# We use `flat_map` as `split_text` has a 1->N relationship:
# it produces N output chunks for each input string.
# Use `map` for a 1->1 relationship.
ds = ds.flat_map(split_text, num_cpus=5)
# print(ds.take(1))
from sentence_transformers import SentenceTransformer
# Embedding model; this can be changed depending on your task.
model_name = "llmrails/ember-v1"
# We use sentence_transformers directly to provide a specific batch size.
# LangChain's HuggingfaceEmbeddings can be used instead once https://github.com/hwchase17/langchain/pull/3914
# is merged.
class Embed:
    def __init__(self):
        # Specify "cuda" to move the model to the GPU.
        self.transformer = SentenceTransformer(model_name, device="cuda")

    def __call__(self, text_batch):
        # We encode with sentence_transformers directly since LangChain's
        # HuggingFaceEmbeddings does not support specifying a batch size yet.
        embeddings = self.transformer.encode(
            text_batch['item'],
            batch_size=10,  # Encoding batch size on the GPU.
            device="cuda",
        ).tolist()
        return {"vector": embeddings, "text": text_batch['item']}
# Use `map_batches` since we want to specify a batch size to maximize GPU utilization.
ds = ds.map_batches(
    Embed,
    # Batch size per actor call. Too large a batch size may run the GPU out of
    # memory; if the chunk size is increased, decrease the batch size, and vice versa.
    batch_size=10,
    compute=ray.data.ActorPoolStrategy(min_size=8, max_size=8),
    num_gpus=1,  # 1 GPU for each actor.
    # num_cpus=8,
)
# text_and_embeddings = []
# for output in ds.iter_rows():
#     insert_dict['text'].append(output['text'])
#     insert_dict['vector'].append(output['vector'])
# # print(ds.take(1))
from pymilvus import (
    connections,
    utility,
    Collection,
    db,
)
connections.connect("default", host='localhost', port=19530)
import milv  # local helper module (not pymilvus) that creates the Milvus collection
col = milv.collection.create_collection(name="milvus_collection")
FIELDS = ['text','vector']
import gc
for batch in ds.iter_batches(batch_size=30, batch_format="numpy"):
    col = Collection("milvus_collection")
    res = col.insert([batch['text'], batch['vector']])
    col.flush()
    del batch
    gc.collect()
Issue Severity
High: It blocks me from completing my task.