Skip to content

Commit e9b5151

Browse files
authored
Shared Executor (#11028)
1 parent 926e4b6 commit e9b5151

File tree

3 files changed: +47 -44 lines changed

libs/langchain/langchain/callbacks/tracers/evaluation.py

Lines changed: 28 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -2,20 +2,33 @@
22
from __future__ import annotations
33

44
import logging
5-
from concurrent.futures import Future, ThreadPoolExecutor
6-
from typing import Any, Dict, List, Optional, Sequence, Set, Union
5+
import weakref
6+
from concurrent.futures import Future, wait
7+
from typing import Any, Dict, List, Optional, Sequence, Union
78
from uuid import UUID
89

910
import langsmith
11+
from langsmith import schemas as langsmith_schemas
1012
from langsmith.evaluation.evaluator import EvaluationResult
1113

1214
from langchain.callbacks import manager
1315
from langchain.callbacks.tracers import langchain as langchain_tracer
1416
from langchain.callbacks.tracers.base import BaseTracer
17+
from langchain.callbacks.tracers.langchain import _get_executor
1518
from langchain.callbacks.tracers.schemas import Run
1619

1720
logger = logging.getLogger(__name__)
1821

22+
_TRACERS: weakref.WeakSet[EvaluatorCallbackHandler] = weakref.WeakSet()
23+
24+
25+
def wait_for_all_evaluators() -> None:
26+
"""Wait for all tracers to finish."""
27+
global _TRACERS
28+
for tracer in list(_TRACERS):
29+
if tracer is not None:
30+
tracer.wait_for_futures()
31+
1932

2033
class EvaluatorCallbackHandler(BaseTracer):
2134
"""A tracer that runs a run evaluator whenever a run is persisted.
@@ -24,9 +37,6 @@ class EvaluatorCallbackHandler(BaseTracer):
2437
----------
2538
evaluators : Sequence[RunEvaluator]
2639
The run evaluators to apply to all top level runs.
27-
max_workers : int, optional
28-
The maximum number of worker threads to use for running the evaluators.
29-
If not specified, it will default to the number of evaluators.
3040
client : LangSmith Client, optional
3141
The LangSmith client instance to use for evaluating the runs.
3242
If not specified, a new instance will be created.
@@ -59,7 +69,6 @@ class EvaluatorCallbackHandler(BaseTracer):
5969
def __init__(
6070
self,
6171
evaluators: Sequence[langsmith.RunEvaluator],
62-
max_workers: Optional[int] = None,
6372
client: Optional[langsmith.Client] = None,
6473
example_id: Optional[Union[UUID, str]] = None,
6574
skip_unfinished: bool = True,
@@ -72,11 +81,14 @@ def __init__(
7281
)
7382
self.client = client or langchain_tracer.get_client()
7483
self.evaluators = evaluators
75-
self.max_workers = max_workers or len(evaluators)
76-
self.futures: Set[Future] = set()
84+
self.executor = _get_executor()
85+
self.futures: weakref.WeakSet[Future] = weakref.WeakSet()
7786
self.skip_unfinished = skip_unfinished
7887
self.project_name = project_name
88+
self.logged_feedback: Dict[str, List[langsmith_schemas.Feedback]] = {}
7989
self.logged_eval_results: Dict[str, List[EvaluationResult]] = {}
90+
global _TRACERS
91+
_TRACERS.add(self)
8092

8193
def _evaluate_in_project(self, run: Run, evaluator: langsmith.RunEvaluator) -> None:
8294
"""Evaluate the run in the project.
@@ -120,15 +132,11 @@ def _persist_run(self, run: Run) -> None:
120132
return
121133
run_ = run.copy()
122134
run_.reference_example_id = self.example_id
123-
if self.max_workers > 0:
124-
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
125-
list(
126-
executor.map(
127-
self._evaluate_in_project,
128-
[run_ for _ in range(len(self.evaluators))],
129-
self.evaluators,
130-
)
131-
)
132-
else:
133-
for evaluator in self.evaluators:
134-
self._evaluate_in_project(run_, evaluator)
135+
for evaluator in self.evaluators:
136+
self.futures.add(
137+
self.executor.submit(self._evaluate_in_project, run_, evaluator)
138+
)
139+
140+
def wait_for_futures(self) -> None:
141+
"""Wait for all futures to complete."""
142+
wait(self.futures)

libs/langchain/langchain/callbacks/tracers/langchain.py

Lines changed: 13 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@
66
import weakref
77
from concurrent.futures import Future, ThreadPoolExecutor, wait
88
from datetime import datetime
9-
from typing import Any, Callable, Dict, List, Optional, Set, Union
9+
from typing import Any, Callable, Dict, List, Optional, Union
1010
from uuid import UUID
1111

1212
from langsmith import Client
@@ -21,8 +21,7 @@
2121
_LOGGED = set()
2222
_TRACERS: weakref.WeakSet[LangChainTracer] = weakref.WeakSet()
2323
_CLIENT: Optional[Client] = None
24-
_MAX_EXECUTORS = 10 # TODO: Remove once write queue is implemented
25-
_EXECUTORS: List[ThreadPoolExecutor] = []
24+
_EXECUTOR: Optional[ThreadPoolExecutor] = None
2625

2726

2827
def log_error_once(method: str, exception: Exception) -> None:
@@ -50,6 +49,14 @@ def get_client() -> Client:
5049
return _CLIENT
5150

5251

52+
def _get_executor() -> ThreadPoolExecutor:
53+
"""Get the executor."""
54+
global _EXECUTOR
55+
if _EXECUTOR is None:
56+
_EXECUTOR = ThreadPoolExecutor()
57+
return _EXECUTOR
58+
59+
5360
class LangChainTracer(BaseTracer):
5461
"""An implementation of the SharedTracer that POSTS to the langchain endpoint."""
5562

@@ -71,21 +78,10 @@ def __init__(
7178
self.project_name = project_name or os.getenv(
7279
"LANGCHAIN_PROJECT", os.getenv("LANGCHAIN_SESSION", "default")
7380
)
74-
if use_threading:
75-
global _MAX_EXECUTORS
76-
if len(_EXECUTORS) < _MAX_EXECUTORS:
77-
self.executor: Optional[ThreadPoolExecutor] = ThreadPoolExecutor(
78-
max_workers=1
79-
)
80-
_EXECUTORS.append(self.executor)
81-
else:
82-
self.executor = _EXECUTORS.pop(0)
83-
_EXECUTORS.append(self.executor)
84-
else:
85-
self.executor = None
8681
self.client = client or get_client()
87-
self._futures: Set[Future] = set()
82+
self._futures: weakref.WeakSet[Future] = weakref.WeakSet()
8883
self.tags = tags or []
84+
self.executor = _get_executor() if use_threading else None
8985
global _TRACERS
9086
_TRACERS.add(self)
9187

@@ -229,7 +225,4 @@ def _on_retriever_error(self, run: Run) -> None:
229225

230226
def wait_for_futures(self) -> None:
231227
"""Wait for the given futures to complete."""
232-
futures = list(self._futures)
233-
wait(futures)
234-
for future in futures:
235-
self._futures.remove(future)
228+
wait(self._futures)

libs/langchain/langchain/smith/evaluation/runner_utils.py

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -24,8 +24,11 @@
2424
from langsmith.schemas import Dataset, DataType, Example
2525

2626
from langchain.callbacks.manager import Callbacks
27-
from langchain.callbacks.tracers.evaluation import EvaluatorCallbackHandler
28-
from langchain.callbacks.tracers.langchain import LangChainTracer, wait_for_all_tracers
27+
from langchain.callbacks.tracers.evaluation import (
28+
EvaluatorCallbackHandler,
29+
wait_for_all_evaluators,
30+
)
31+
from langchain.callbacks.tracers.langchain import LangChainTracer
2932
from langchain.chains.base import Chain
3033
from langchain.evaluation.loading import load_evaluator
3134
from langchain.evaluation.schema import (
@@ -915,7 +918,6 @@ def _prepare_run_on_dataset(
915918
EvaluatorCallbackHandler(
916919
evaluators=run_evaluators or [],
917920
client=client,
918-
max_workers=0,
919921
example_id=example.id,
920922
),
921923
progress_bar,
@@ -934,7 +936,7 @@ def _collect_test_results(
934936
configs: List[RunnableConfig],
935937
project_name: str,
936938
) -> TestResult:
937-
wait_for_all_tracers()
939+
wait_for_all_evaluators()
938940
all_eval_results = {}
939941
for c in configs:
940942
for callback in cast(list, c["callbacks"]):

0 commit comments

Comments
 (0)