-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathaugment_graph.py
More file actions
219 lines (174 loc) · 7.57 KB
/
augment_graph.py
File metadata and controls
219 lines (174 loc) · 7.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
#!/usr/bin/env python3
"""Augment the existing temporal graph with LLM-extracted features.
Loads the current graph (.pt), merges LLM features from parquet files,
and saves a new graph with expanded node features. Tickers without LLM
features for a quarter are filled with that quarter's per-feature median
(0 when no ticker in the quarter has a value for that feature).
Usage:
python augment_graph.py
python augment_graph.py --input data/graphs/sp500_2015_2024.pt --output data/graphs/sp500_2015_2024_llm.pt
"""
from __future__ import annotations
import argparse
import sys
from pathlib import Path
import numpy as np
import pandas as pd
sys.path.insert(0, str(Path(__file__).parent / "src"))
from fingraph.graph.graph_store import load_graph, save_graph
from fingraph.graph.graph_builder import GraphSnapshot, TemporalGraph
BASE_DIR = Path(__file__).parent
LLM_DIR = BASE_DIR.joinpath("data", "features", "language")

# LLM-extracted columns appended (in this order) to every snapshot's node
# features.
#
# Excluded because the LLM hallucinated large numbers instead of returning
# mention counts:
#   customer_count, supplier_count, competitor_count
#
# Excluded for low variance (values cluster around 0.5, nearly constant):
#   sentiment_confidence_business, sentiment_confidence_risk_factors,
#   sentiment_confidence_mda, risk_max_severity
#
# The 8 retained features have usable distributions (observed range noted):
LLM_FEATURE_COLS = [
    # Section sentiment scores.
    "sentiment_business",      # [-0.95, 0.99] good spread
    "sentiment_risk_factors",  # [-1.0, 0.98] good spread
    "sentiment_mda",           # [-0.9, 0.98] good spread
    # Hedging-language scores.
    "hedging_risk_factors",    # [0, 1] good spread
    "hedging_mda",             # [0, 1] good spread
    # Count features.
    "risk_factor_count",       # [0, 56] reasonable integers
    "entity_count_business",   # [0, 69] mostly 0-4
    "entity_count_mda",        # [0, 28] mostly 1-3
]
def load_llm_features() -> pd.DataFrame:
    """Load and merge all LLM feature parquets found in ``LLM_DIR``.

    Files are read in this order: ``llm_features.parquet`` (the original
    single-worker run), then every ``llm_features_worker<N>.parquet`` in
    ascending numeric ``N`` (any number of workers). Rows are de-duplicated on
    ``(ticker, year, quarter)``, keeping the row from the file read LAST, so a
    later worker file overrides earlier files for the same key.

    Every column in ``LLM_FEATURE_COLS`` is guaranteed to exist. Missing values
    are deliberately left as NaN (an absent column is added as all-NaN) so the
    graph-augmentation step can fill them with per-quarter medians. Filling
    with 0 here would inject "neutral sentiment" for data that was never
    extracted.

    Returns:
        The merged, de-duplicated DataFrame.

    Raises:
        FileNotFoundError: if neither the main file nor any worker file exists.
    """
    dfs = []

    # Main file (from original single-worker run).
    main_file = LLM_DIR / "llm_features.parquet"
    if main_file.exists():
        df = pd.read_parquet(main_file)
        print(f" Main file: {len(df)} records")
        dfs.append(df)

    # Worker files: discover all of them (not just 0-9) and sort numerically
    # so e.g. worker10 is read after worker9, which matters for keep="last".
    prefix = "llm_features_worker"
    worker_files = []
    for path in LLM_DIR.glob(f"{prefix}*.parquet"):
        suffix = path.stem[len(prefix):]
        if suffix.isdecimal():
            worker_files.append((int(suffix), path))
    for i, worker_file in sorted(worker_files):
        df = pd.read_parquet(worker_file)
        print(f" Worker {i}: {len(df)} records")
        dfs.append(df)

    if not dfs:
        raise FileNotFoundError("No LLM feature files found!")

    merged = pd.concat(dfs, ignore_index=True)

    # De-duplicate, keeping the row from the file read last (see docstring).
    merged = merged.drop_duplicates(subset=["ticker", "year", "quarter"], keep="last")

    # Ensure every expected column exists; absent ones are all-missing (NaN),
    # not zero, so downstream median filling treats them as missing.
    for col in LLM_FEATURE_COLS:
        if col not in merged.columns:
            merged[col] = np.nan

    print(f" Merged: {len(merged)} records, {merged['ticker'].nunique()} tickers")
    return merged
def quarter_label_to_year_q(label: str):
    """Split a quarter label such as ``'2020Q3'`` into ``(2020, 3)``.

    The year is taken from the first four characters and the quarter from the
    final character; whatever lies between them (e.g. ``'Q'``) is ignored.
    """
    year_text, quarter_text = label[:4], label[-1]
    return int(year_text), int(quarter_text)
def _snapshot_llm_block(
    quarter_llm: pd.DataFrame,
    ticker_to_idx: dict[str, int],
    num_nodes: int,
) -> tuple[np.ndarray, int]:
    """Build the LLM feature block for one snapshot.

    Args:
        quarter_llm: LLM rows already restricted to the snapshot's quarter.
        ticker_to_idx: The snapshot's ticker -> node row index mapping.
        num_nodes: Number of node rows in the snapshot.

    Returns:
        ``(llm_features, covered)``. ``llm_features`` is a float32 array of
        shape ``(num_nodes, len(LLM_FEATURE_COLS))``: every cell starts at the
        per-quarter median of that column (0 if the column has no value in
        this quarter) and is overwritten wherever a node has a real value.
        ``covered`` is the number of distinct nodes that received at least one
        real LLM value.
    """
    col_values = {col: [] for col in LLM_FEATURE_COLS}
    node_data = {}  # node idx -> list of float-or-None, one per feature
    for _, row in quarter_llm.iterrows():
        ticker = row["ticker"]
        if ticker not in ticker_to_idx:
            continue  # ticker is not a node in this snapshot
        vals = []
        for col in LLM_FEATURE_COLS:
            val = row.get(col, None)
            # pd.isna catches None, np.float32/np.float64 NaN and pd.NA; an
            # isinstance(val, float) check would miss float32 NaN and let it
            # poison np.median for the whole column.
            if pd.isna(val):
                vals.append(None)
            else:
                v = float(val)
                col_values[col].append(v)
                vals.append(v)
        node_data[ticker_to_idx[ticker]] = vals

    # Per-quarter medians are the fill values; an all-missing column stays 0.
    medians = np.zeros(len(LLM_FEATURE_COLS), dtype=np.float32)
    for j, col in enumerate(LLM_FEATURE_COLS):
        if col_values[col]:
            medians[j] = np.median(col_values[col])

    # Start every node at the medians, then override with real data.
    llm_features = np.tile(medians, (num_nodes, 1))
    covered = 0
    for idx, vals in node_data.items():
        observed = [(j, v) for j, v in enumerate(vals) if v is not None]
        for j, v in observed:
            llm_features[idx, j] = v
        if observed:
            covered += 1
    return llm_features, covered


def augment_graph(graph: TemporalGraph, llm_df: pd.DataFrame) -> TemporalGraph:
    """Append RAW LLM features to each snapshot (no z-score normalization).

    The original node features are raw (not normalized), so the LLM features
    are kept raw too for consistent scale treatment; the model's input
    projection layer is expected to handle scale differences.

    Missing values are filled with the per-quarter MEDIAN of the values
    available for tickers in that snapshot, not with zero: zero would mean
    "neutral sentiment", which is wrong for nodes with no data. A column with
    no value at all in a quarter falls back to 0.

    Args:
        graph: Source temporal graph; its snapshots are not modified.
        llm_df: LLM features with ``ticker``, ``year`` and ``quarter`` columns
            plus (some of) ``LLM_FEATURE_COLS``.

    Returns:
        A new ``TemporalGraph`` whose feature names are the original names
        followed by ``LLM_FEATURE_COLS``.
    """
    new_feature_names = graph.feature_names + LLM_FEATURE_COLS
    new_snapshots = []
    total_covered = 0  # node-quarters with at least one real LLM value
    total_nodes = 0
    for snap in graph.snapshots:
        year, q = quarter_label_to_year_q(snap.quarter_label)
        quarter_llm = llm_df[(llm_df["year"] == year) & (llm_df["quarter"] == q)]

        llm_features, covered = _snapshot_llm_block(
            quarter_llm, snap.ticker_to_idx, snap.num_nodes
        )
        total_covered += covered
        total_nodes += snap.num_nodes

        # Concatenate original features + LLM features column-wise.
        new_node_features = np.concatenate([snap.node_features, llm_features], axis=1)
        new_snapshots.append(
            GraphSnapshot(
                quarter=snap.quarter,
                quarter_label=snap.quarter_label,
                ticker_to_idx=snap.ticker_to_idx,
                node_features=new_node_features,
                feature_names=new_feature_names,
                edge_index=snap.edge_index,
                edge_types=snap.edge_types,
                edge_weights=snap.edge_weights,
                labels=snap.labels,
            )
        )

    coverage = total_covered / max(total_nodes, 1) * 100
    print(f" LLM coverage: {total_covered}/{total_nodes} node-quarters ({coverage:.1f}%)")
    print(" Missing nodes filled with per-quarter medians")
    return TemporalGraph(
        snapshots=new_snapshots,
        tickers=graph.tickers,
        feature_names=new_feature_names,
    )
def main():
    """Command-line entry point.

    Parses the ``--input`` / ``--output`` graph paths, loads the LLM features
    and the source graph, appends the LLM columns to every snapshot, and
    writes the augmented graph, printing a before/after summary on the way.
    """
    graphs_dir = BASE_DIR / "data" / "graphs"
    cli = argparse.ArgumentParser(description="Augment graph with LLM features")
    cli.add_argument(
        "--input", type=str, default=str(graphs_dir / "sp500_2015_2024.pt")
    )
    cli.add_argument(
        "--output", type=str, default=str(graphs_dir / "sp500_2015_2024_llm.pt")
    )
    opts = cli.parse_args()

    print("Loading LLM features...")
    llm_df = load_llm_features()

    print(f"\nLoading graph from {opts.input}...")
    graph = load_graph(opts.input)
    before = graph.summary()
    print(
        f" Original: {before['quarters']} quarters, {before['tickers']} tickers, "
        f"{before['features']} features"
    )

    print("\nAugmenting graph with LLM features...")
    new_graph = augment_graph(graph, llm_df)
    after = new_graph.summary()
    print(
        f" New: {after['quarters']} quarters, {after['tickers']} tickers, "
        f"{after['features']} features"
    )
    print(
        f" Features: {after['features']} = {before['features']} original "
        f"+ {len(LLM_FEATURE_COLS)} LLM"
    )

    print(f"\nSaving to {opts.output}...")
    save_graph(new_graph, opts.output)
    print("Done!")


if __name__ == "__main__":
    main()