@@ -6,6 +6,15 @@ import {FlattenedNode} from './shapes/nodeShapes';
66import TreeState , { State } from './state/TreeState' ;
77
88export default class Tree extends React . Component {
9+ constructor ( props ) {
10+ super ( props ) ;
11+ this . state = {
12+ stickyHeaders : [ ] , // To track all visible group headers
13+ topStickyHeader : null , // The header that should be sticky
14+ } ;
15+ this . _listRef = React . createRef ( ) ;
16+ }
17+
918 _cache = new CellMeasurerCache ( {
1019 fixedWidth : true ,
1120 minHeight : 20 ,
@@ -35,8 +44,101 @@ export default class Tree extends React.Component {
3544 : nodes [ index ] ;
3645 } ;
3746
47+ // Determine if a node is a group header
48+ isGroupHeader = node => {
49+ // Group headers are typically parent nodes with children
50+ // and deepness of 0 (root level)
51+ return node . children && node . children . length > 0 && node . deepness === 0 ;
52+ } ;
53+
54+ componentDidMount ( ) {
55+ // Initial check for headers after mounting
56+ if ( this . _listRef . current ) {
57+ const list = this . _listRef . current ;
58+ const grid = list && list . Grid ;
59+ if ( grid ) {
60+ this . handleScroll ( {
61+ scrollTop : grid . state . scrollTop ,
62+ scrollHeight : grid . state . scrollHeight ,
63+ clientHeight : grid . props . height ,
64+ } ) ;
65+ }
66+ }
67+ }
68+
69+ // Get all headers in the current data
70+ getAllHeaders = ( ) => {
71+ const rowCount = this . getRowCount ( ) ;
72+ const headers = [ ] ;
73+
74+ for ( let i = 0 ; i < rowCount ; i ++ ) {
75+ const node = this . getNode ( i ) ;
76+ if ( this . isGroupHeader ( node ) ) {
77+ // Calculate the position by summing heights of all rows before this one
78+ let top = 0 ;
79+ for ( let j = 0 ; j < i ; j ++ ) {
80+ top += this . _cache . rowHeight ( { index : j } ) ;
81+ }
82+
83+ headers . push ( {
84+ node,
85+ index : i ,
86+ top,
87+ } ) ;
88+ }
89+ }
90+
91+ return headers ;
92+ } ;
93+
94+ // Handle scroll events to update sticky headers
95+ handleScroll = ( { scrollTop, scrollHeight, clientHeight} ) => {
96+ if ( ! this . _listRef . current ) return ;
97+
98+ // Get all headers in the tree
99+ const allHeaders = this . getAllHeaders ( ) ;
100+
101+ // Find headers that should be visible based on scroll position
102+ const visibleHeaders = allHeaders . filter ( header => {
103+ // Calculate the bottom position of this header row
104+ const headerHeight = this . _cache . rowHeight ( { index : header . index } ) ;
105+ const headerBottom = header . top + headerHeight ;
106+
107+ // Header is visible if:
108+ // 1. Its top is between scrollTop and scrollTop + clientHeight, OR
109+ // 2. Its bottom is between scrollTop and scrollTop + clientHeight, OR
110+ // 3. It starts before scrollTop and ends after scrollTop + clientHeight
111+ return (
112+ ( header . top >= scrollTop && header . top <= scrollTop + clientHeight ) ||
113+ ( headerBottom >= scrollTop && headerBottom <= scrollTop + clientHeight ) ||
114+ ( header . top <= scrollTop && headerBottom >= scrollTop + clientHeight )
115+ ) ;
116+ } ) ;
117+
118+ // Find the header that should be sticky
119+ // It's the last header whose top position is less than or equal to scrollTop
120+ const headersBeforeViewport = allHeaders . filter ( h => h . top <= scrollTop ) ;
121+ const topStickyHeader =
122+ headersBeforeViewport . length > 0 ? headersBeforeViewport [ headersBeforeViewport . length - 1 ] : null ;
123+
124+ // Only update state if something has changed
125+ const currentStickyId = this . state . topStickyHeader && this . state . topStickyHeader . node && this . state . topStickyHeader . node . id ;
126+ const newStickyId = topStickyHeader && topStickyHeader . node && topStickyHeader . node . id ;
127+
128+ if ( currentStickyId !== newStickyId || this . state . stickyHeaders . length !== visibleHeaders . length ) {
129+ this . setState ( {
130+ stickyHeaders : visibleHeaders ,
131+ topStickyHeader,
132+ } ) ;
133+ }
134+ } ;
135+
38136 rowRenderer = ( { node, key, measure, style, NodeRenderer, index} ) => {
39137 const { nodeMarginLeft} = this . props ;
138+ const isHeader = this . isGroupHeader ( node ) ;
139+
140+ // Add a class to identify group headers
141+ const className = isHeader ? 'tree-group-header' : '' ;
40142
41143 return (
42144 < NodeRenderer
@@ -47,14 +149,49 @@ export default class Tree extends React.Component {
47149 userSelect : 'none' ,
48150 cursor : 'pointer' ,
49151 } }
152+ className = { className }
50153 node = { node }
51154 onChange = { this . props . onChange }
52155 measure = { measure }
53156 index = { index }
157+ isGroupHeader = { isHeader }
54158 />
55159 ) ;
56160 } ;
57161
162+ // Render the sticky header
163+ renderStickyHeader = ( ) => {
164+ const { topStickyHeader} = this . state ;
165+ if ( ! topStickyHeader ) return null ;
166+
167+ const { NodeRenderer, nodeMarginLeft} = this . props ;
168+ // Always use the current node from the tree to ensure we have the latest state
169+ const index = topStickyHeader . index ;
170+ const currentNode = this . getNode ( index ) ;
171+
172+ return (
173+ < div className = "tree-sticky-header" >
174+ < NodeRenderer
175+ key = { `sticky-header-${ currentNode . id } ` }
176+ style = { {
177+ marginLeft : currentNode . deepness * nodeMarginLeft ,
178+ userSelect : 'none' ,
179+ cursor : 'pointer' ,
180+ width : '100%' ,
181+ background : '#fff' , // Background to ensure visibility
182+ zIndex : 10 ,
183+ } }
184+ className = "tree-group-header tree-sticky"
185+ node = { currentNode }
186+ onChange = { this . props . onChange }
187+ index = { index }
188+ isGroupHeader = { true }
189+ isSticky = { true }
190+ />
191+ </ div >
192+ ) ;
193+ } ;
194+
58195 measureRowRenderer = nodes => ( { key, index, style, parent} ) => {
59196 const { NodeRenderer} = this . props ;
60197 const node = this . getNode ( index ) ;
@@ -66,25 +203,64 @@ export default class Tree extends React.Component {
66203 ) ;
67204 } ;
68205
206+ componentDidUpdate ( prevProps ) {
207+ // If nodes change, reset the cache
208+ if ( prevProps . nodes !== this . props . nodes ) {
209+ this . _cache . clearAll ( ) ;
210+ if ( this . _listRef . current ) {
211+ this . _listRef . current . recomputeRowHeights ( ) ;
212+ }
213+
214+ // Force rerender of sticky header when nodes change
215+ this . forceUpdate ( ) ;
216+ }
217+ }
218+
69219 render ( ) {
70220 const { nodes, width, scrollToIndex, scrollToAlignment} = this . props ;
221+ const { topStickyHeader} = this . state ;
222+
223+ // Calculate the height of the sticky header to properly offset the list
224+ const stickyHeaderHeight = topStickyHeader ? this . _cache . rowHeight ( { index : topStickyHeader . index } ) : 0 ;
71225
72226 return (
73- < AutoSizer disableWidth = { Boolean ( width ) } >
74- { ( { height, width : autoWidth } ) => (
75- < List
76- deferredMeasurementCache = { this . _cache }
77- ref = { r => ( this . _list = r ) }
78- height = { height }
79- rowCount = { this . getRowCount ( ) }
80- rowHeight = { this . _cache . rowHeight }
81- rowRenderer = { this . measureRowRenderer ( nodes ) }
82- width = { width || autoWidth }
83- scrollToIndex = { scrollToIndex }
84- scrollToAlignment = { scrollToAlignment }
85- />
227+ < div className = "tree-container" style = { { position : 'relative' , height : '100%' } } >
228+ { /* Sticky header container */ }
229+ { topStickyHeader && (
230+ < div
231+ className = "tree-sticky-header-container"
232+ style = { {
233+ position : 'absolute' ,
234+ top : 0 ,
235+ left : 0 ,
236+ right : 0 ,
237+ zIndex : 100 ,
238+ height : `${ stickyHeaderHeight } px` ,
239+ } }
240+ >
241+ { this . renderStickyHeader ( ) }
242+ </ div >
86243 ) }
87- </ AutoSizer >
244+
245+ < AutoSizer disableWidth = { Boolean ( width ) } >
246+ { ( { height, width : autoWidth } ) => (
247+ < List
248+ deferredMeasurementCache = { this . _cache }
249+ ref = { this . _listRef }
250+ height = { height }
251+ rowCount = { this . getRowCount ( ) }
252+ rowHeight = { this . _cache . rowHeight }
253+ rowRenderer = { this . measureRowRenderer ( nodes ) }
254+ width = { width || autoWidth }
255+ scrollToIndex = { scrollToIndex }
256+ scrollToAlignment = { scrollToAlignment }
257+ onScroll = { this . handleScroll }
258+ // Important: adds overscan to ensure we load enough rows to find headers
259+ overscanRowCount = { 20 }
260+ />
261+ ) }
262+ </ AutoSizer >
263+ </ div >
88264 ) ;
89265 }
90266}
0 commit comments