base.py
from abc import ABC, ABCMeta
from concurrent.futures import ThreadPoolExecutor
from typing import Any, List, Optional

import pandas as pd

from continuous_eval.llm_factory import DefaultLLM, LLMInterface
from continuous_eval.utils.telemetry import telemetry


class MetricDecoratorMeta(ABCMeta, type):
    """Metaclass that wraps `__call__` with telemetry instrumentation."""

    def __new__(cls, name, bases, dct):
        for attr, value in dct.items():
            if callable(value) and attr == '__call__':
                dct[attr] = telemetry.metric_telemetry(value)
            elif callable(value) and attr == 'batch':
                pass
                # dct[attr] = telemetry.batch_metric_telemetry(value)
        return type.__new__(cls, name, bases, dct)


class Metric(ABC, metaclass=MetricDecoratorMeta):
    def __init__(self) -> None:
        super().__init__()
        self._overloaded_params = None
        self.max_workers = 32

    def use(self, **kwargs) -> "Metric":
        self._overloaded_params = kwargs
        return self

    @property
    def overloaded_params(self):
        return self._overloaded_params

    def __call__(self, **kwargs):
        # Implement this method in the subclass
        raise NotImplementedError()

    def batch(self, **kwargs) -> Any:
        # Transpose keyword arguments of equal-length lists into per-sample kwargs
        kwargs_ = [{key: kwargs[key][i] for key in kwargs} for i in range(len(next(iter(kwargs.values()))))]
        if self.max_workers <= 1:
            return [self(**kw) for kw in kwargs_]
        instances = []
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            future_instances = [executor.submit(lambda kw: self(**kw), kw) for kw in kwargs_]
            for future in future_instances:
                instances.append(future.result())
        return instances

    def aggregate(self, results: List[Any]) -> Any:
        # Default implementation: drop list/str fields and average the remaining numeric fields
        sanitize = lambda results: [{k: v for k, v in r.items() if not isinstance(v, (list, str))} for r in results]
        agg = pd.DataFrame(sanitize(results))
        return agg.mean().to_dict()

    @property
    def name(self):
        return self.__class__.__name__

    def asdict(self):
        return {
            "__class__": self.__class__.__name__,
            "name": self.name,
        }


class LLMBasedMetric(Metric):
    """
    Base class for all LLM based metrics.
    """

    def __init__(self, model: Optional[LLMInterface] = None):
        super().__init__()
        if model is None:
            self._llm = DefaultLLM()
        else:
            self._llm = model
        assert isinstance(self._llm, LLMInterface), "model must be an instance of LLMInterface."
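
A minimal usage sketch (not part of base.py): a hypothetical `ExactMatchSketch` subclass implements `__call__` for a single sample; `batch` then transposes the list-valued keyword arguments into per-sample calls run on a thread pool, and the default `aggregate` averages the numeric fields. The class name and its keyword arguments are illustrative assumptions, not part of the library.

class ExactMatchSketch(Metric):  # hypothetical example subclass
    def __call__(self, answer: str, ground_truth: str, **kwargs):
        # Return a dict of numeric fields so the default aggregate() can average them
        return {"exact_match": float(answer.strip() == ground_truth.strip())}

metric = ExactMatchSketch()
results = metric.batch(
    answer=["Paris", "Berlin"],
    ground_truth=["Paris", "Rome"],
)
# results -> [{"exact_match": 1.0}, {"exact_match": 0.0}]
print(metric.aggregate(results))
# -> {"exact_match": 0.5}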