Commit d92fe97 (parent f179b4b)
docs: transforms API (#134)
Authored by besaleli and angpt (Co-authored-by: angpt <anushrigupta@gmail.com>)

File changed: docs/transforms/index.md (197 additions, 1 deletion)
# Transforms

Transforms allow you to post-process model outputs after ONNX inference and before returning results. They run inside the model binary, operating directly on tensors for high performance.

Transforms run on Lua 5.4 in a sandboxed environment. Currently, the transforms feature does not support LuaJIT.

## Why Use Transforms?

Common use cases:

- **Normalize embeddings** for cosine similarity
- **Apply softmax** to convert logits to probabilities
- **Pool embeddings** to create sentence representations
- **Scale outputs** for specific downstream tasks
## Getting Started

A transform is a Lua script that defines a `Postprocess` function:

```lua
---@param arr Tensor
---@return Tensor
function Postprocess(arr, ...)
    -- your postprocessing logic
    return tensor
end
```

With a handful of exceptions, the `Postprocess` function must return a `Tensor` with exactly the same shape as the input `Tensor` provided for that model type. The exceptions are as follows:

- Embedding and sentence embedding models can modify the length of `hidden` (useful for matryoshka embeddings).
- Sentence embedding models are given a `Tensor` of shape `[batch_size, seq_len, hidden]` and an attention mask of shape `[batch_size, seq_len]`, and must return a `Tensor` of shape `[batch_size, hidden]`. In other words, a pooling operation along the `seq_len` dimension is expected.

!!! note "Note on indexing"
    Lua is 1-indexed, meaning that it starts counting at 1 instead of 0. The `Tensor` API reflects this: count your axes and indices starting at 1 instead of 0.
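For example, in plain Lua:

```lua
local t = {10, 20, 30}

print(t[1])  -- 10: the first element lives at index 1
print(t[0])  -- nil: index 0 is not used by Lua tables
print(#t)    -- 3: the length operator also assumes 1-based indexing
```

The same convention applies to `Tensor` axes: the first axis of a `[batch_size, seq_len, hidden]` tensor is axis 1.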
We provide a built-in API for standard tensor operations. To learn more, check out our [Tensor API reference page](reference). You can find the stub file [here](https://github.com/mozilla-ai/encoderfile/blob/main/encoderfile-core/stubs/lua/tensor.lua).

If you don't see an op that you need, please don't hesitate to [create an issue](https://github.com/mozilla-ai/encoderfile/issues) on GitHub.

## Input Signatures

The input signature of `Postprocess` depends on the type of model being used.
### Embedding

```lua
--- input: 3d tensor of shape [batch_size, seq_len, hidden]
---@param arr Tensor
--- output: 3d tensor of shape [batch_size, seq_len, hidden]
---@return Tensor
function Postprocess(arr)
    -- your postprocessing logic
    return tensor
end
```
### Sequence Classification

```lua
--- input: 2d tensor of shape [batch_size, n_labels]
---@param arr Tensor
--- output: 2d tensor of shape [batch_size, n_labels]
---@return Tensor
function Postprocess(arr)
    -- your postprocessing logic
    return tensor
end
```
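As a concrete sketch, a sequence-classification transform that converts logits to probabilities can apply the `softmax` op from the Tensor API across the label axis (axis 2 under 1-based indexing):

```lua
---@param arr Tensor  -- logits of shape [batch_size, n_labels]
---@return Tensor
function Postprocess(arr)
    -- softmax across axis 2 (the n_labels dimension) preserves shape
    return arr:softmax(2)
end
```

The `Tensor` userdata here is provided by the model binary at runtime; the script cannot run standalone.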
### Token Classification

```lua
--- input: 3d tensor of shape [batch_size, seq_len, n_labels]
---@param arr Tensor
--- output: 3d tensor of shape [batch_size, seq_len, n_labels]
---@return Tensor
function Postprocess(arr)
    -- your postprocessing logic
    return tensor
end
```
### Sentence Embedding

!!! note "Mean Pooling"
    To mean-pool embeddings, you can use the `Tensor:mean_pool` function like this: `tensor:mean_pool(mask)`.

```lua
--- input: 3d tensor of shape [batch_size, seq_len, hidden]
---@param arr Tensor
--- input: 2d tensor of shape [batch_size, seq_len]
--- This is provided to the function automatically and is equivalent to 🤗 Transformers' attention_mask.
---@param mask Tensor
--- output: 2d tensor of shape [batch_size, hidden]
---@return Tensor
function Postprocess(arr, mask)
    -- your postprocessing logic
    return tensor
end
```
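Putting the pieces together, a complete sentence-embedding transform can simply delegate to the built-in `mean_pool` op:

```lua
---@param arr Tensor   -- [batch_size, seq_len, hidden]
---@param mask Tensor  -- [batch_size, seq_len] attention mask
---@return Tensor      -- [batch_size, hidden]
function Postprocess(arr, mask)
    -- mean-pool along seq_len, using the mask to ignore padded positions
    return arr:mean_pool(mask)
end
```

This satisfies the shape contract above: the `seq_len` dimension is pooled away, leaving `[batch_size, hidden]`.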
## Typical Transform Patterns

Most transforms fall into one of three patterns:

### 1. Elementwise Transforms

Safe: they preserve shape automatically.

Examples:

- scaling (`tensor * 1.5`)
- activation functions (`tensor:exp()`)
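A sketch combining the two elementwise ops above; since both apply per element, the output shape matches the input automatically:

```lua
---@param arr Tensor
---@return Tensor
function Postprocess(arr)
    -- scale, then exponentiate; elementwise ops never change the shape
    return (arr * 1.5):exp()
end
```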
### 2. Normalization Across an Axis

These also preserve shape.

Examples:

- Lp normalization (`tensor:lp_normalize(p, axis)`)
- subtracting the mean per batch or per token
- applying softmax across a specific dimension (`tensor:softmax(2)`)
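For instance, an embedding transform that L2-normalizes along the hidden axis (axis 3 of `[batch_size, seq_len, hidden]` under 1-based indexing) is a one-liner with `lp_normalize`:

```lua
---@param arr Tensor  -- [batch_size, seq_len, hidden]
---@return Tensor
function Postprocess(arr)
    -- p = 2 gives L2 normalization; axis 3 is the hidden dimension
    return arr:lp_normalize(2, 3)
end
```

This is the usual preparation for cosine-similarity search, since normalized vectors make cosine similarity a plain dot product.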
### 3. Mask-Aware Adjustments

When working with sentence embedding models:

```lua
function Postprocess(arr, mask)
    -- embeddings: [batch, seq, hidden]
    -- mask: [batch, seq]

    -- operations here must output [batch, hidden]
    return ...
end
```
## Best Practices

!!! warning "Performance Implications"
    Transforms run synchronously during inference, so expensive Lua-side loops will increase latency. If you don't see an op that you need, please don't hesitate to [create an issue](https://github.com/mozilla-ai/encoderfile/issues) on GitHub.

A typical transform follows this structure:

```lua
function Postprocess(arr, ...)
    -- Step 1: apply elementwise or axis-based operations
    local modified = arr:exp() -- example

    -- Step 2: ensure the output shape matches the input shape
    -- (all built-in ops described in the Tensor API preserve shape)

    return modified
end
```
## Debugging Transforms

You can inspect shape and values using:

```lua
print("ndim:", t:ndim())
print("len:", #t)
print(tostring(t))
```
Errors typically fall into:

- **axis out of range** → the axis must be 1-indexed and ≤ the tensor rank
- **broadcasting errors** → the two shapes are incompatible
- **returned value is not a tensor** → `Postprocess` must return a Tensor userdata object
- **shape mismatch** → the transform modified the rank or dimensions
## Configuration

Transforms are embedded at build time. You can specify them in your `config.yml` either as a file path or inline.

```yml
transform:
  path: path/to/your/transform/here
```

Or, they can be passed inline:

```yml
transform: |
  function Postprocess(arr)
    ...
    return arr
  end
```
