22from __future__ import annotations
33
44import logging
5- from concurrent .futures import Future , ThreadPoolExecutor
6- from typing import Any , Dict , List , Optional , Sequence , Set , Union
5+ import weakref
6+ from concurrent .futures import Future , wait
7+ from typing import Any , Dict , List , Optional , Sequence , Union
78from uuid import UUID
89
910import langsmith
11+ from langsmith import schemas as langsmith_schemas
1012from langsmith .evaluation .evaluator import EvaluationResult
1113
1214from langchain .callbacks import manager
1315from langchain .callbacks .tracers import langchain as langchain_tracer
1416from langchain .callbacks .tracers .base import BaseTracer
17+ from langchain .callbacks .tracers .langchain import _get_executor
1518from langchain .callbacks .tracers .schemas import Run
1619
1720logger = logging .getLogger (__name__ )
1821
# Registry of every EvaluatorCallbackHandler that is still alive. Handlers
# add themselves on construction; the set is weak so that being registered
# here never keeps a handler from being garbage collected.
_TRACERS: "weakref.WeakSet[EvaluatorCallbackHandler]" = weakref.WeakSet()


def wait_for_all_evaluators() -> None:
    """Wait for all registered evaluator tracers to finish.

    Calls ``wait_for_futures()`` on each handler currently held in the
    module-level ``_TRACERS`` registry. Does nothing when the registry is
    empty.
    """
    # Iterate over a snapshot: membership of the weak registry may change
    # while this function is blocked waiting on an individual tracer.
    # WeakSet iteration only yields live objects, so no None check is needed.
    for tracer in list(_TRACERS):
        tracer.wait_for_futures()
31+
1932
2033class EvaluatorCallbackHandler (BaseTracer ):
2134 """A tracer that runs a run evaluator whenever a run is persisted.
@@ -24,9 +37,6 @@ class EvaluatorCallbackHandler(BaseTracer):
2437 ----------
2538 evaluators : Sequence[RunEvaluator]
2639 The run evaluators to apply to all top level runs.
27- max_workers : int, optional
28- The maximum number of worker threads to use for running the evaluators.
29- If not specified, it will default to the number of evaluators.
3040 client : LangSmith Client, optional
3141 The LangSmith client instance to use for evaluating the runs.
3242 If not specified, a new instance will be created.
@@ -59,7 +69,6 @@ class EvaluatorCallbackHandler(BaseTracer):
5969 def __init__ (
6070 self ,
6171 evaluators : Sequence [langsmith .RunEvaluator ],
62- max_workers : Optional [int ] = None ,
6372 client : Optional [langsmith .Client ] = None ,
6473 example_id : Optional [Union [UUID , str ]] = None ,
6574 skip_unfinished : bool = True ,
@@ -72,11 +81,14 @@ def __init__(
7281 )
7382 self .client = client or langchain_tracer .get_client ()
7483 self .evaluators = evaluators
75- self .max_workers = max_workers or len ( evaluators )
76- self .futures : Set [Future ] = set ()
84+ self .executor = _get_executor ( )
85+ self .futures : weakref . WeakSet [Future ] = weakref . WeakSet ()
7786 self .skip_unfinished = skip_unfinished
7887 self .project_name = project_name
88+ self .logged_feedback : Dict [str , List [langsmith_schemas .Feedback ]] = {}
7989 self .logged_eval_results : Dict [str , List [EvaluationResult ]] = {}
90+ global _TRACERS
91+ _TRACERS .add (self )
8092
8193 def _evaluate_in_project (self , run : Run , evaluator : langsmith .RunEvaluator ) -> None :
8294 """Evaluate the run in the project.
@@ -120,15 +132,11 @@ def _persist_run(self, run: Run) -> None:
120132 return
121133 run_ = run .copy ()
122134 run_ .reference_example_id = self .example_id
123- if self .max_workers > 0 :
124- with ThreadPoolExecutor (max_workers = self .max_workers ) as executor :
125- list (
126- executor .map (
127- self ._evaluate_in_project ,
128- [run_ for _ in range (len (self .evaluators ))],
129- self .evaluators ,
130- )
131- )
132- else :
133- for evaluator in self .evaluators :
134- self ._evaluate_in_project (run_ , evaluator )
135+ for evaluator in self .evaluators :
136+ self .futures .add (
137+ self .executor .submit (self ._evaluate_in_project , run_ , evaluator )
138+ )
139+
def wait_for_futures(self) -> None:
    """Wait for all futures to complete.

    Blocks until every evaluation future currently tracked in
    ``self.futures`` has finished. Returns immediately when there are none.
    """
    # ``self.futures`` is a WeakSet that other code keeps adding to. Take a
    # snapshot so ``wait`` works on a fixed collection of strong references
    # rather than a set whose membership can change underneath it.
    pending = list(self.futures)
    wait(pending)
0 commit comments