// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Runtime.CompilerServices;
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;
using Microsoft.ML.Internal.CpuMath.Core;

namespace Microsoft.ML.Internal.CpuMath.FactorizationMachine
{
    internal static class AvxIntrinsics
    {
        private static readonly Vector256<float> _point5 = Vector256.Create(0.5f);

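        // Computes (src1 * src2) + src3 element-wise, using a fused multiply-add when the FMA
        // instruction set is available and a separate multiply and add otherwise.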
        [MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
        private static Vector256<float> MultiplyAdd(Vector256<float> src1, Vector256<float> src2, Vector256<float> src3)
        {
            if (Fma.IsSupported)
            {
                return Fma.MultiplyAdd(src1, src2, src3);
            }
            else
            {
                Vector256<float> product = Avx.Multiply(src1, src2);
                return Avx.Add(product, src3);
            }
        }

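        // Computes src3 - (src1 * src2) element-wise, using a fused negated multiply-add when the FMA
        // instruction set is available and a separate multiply and subtract otherwise.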
        [MethodImplAttribute(MethodImplOptions.AggressiveInlining)]
        private static Vector256<float> MultiplyAddNegated(Vector256<float> src1, Vector256<float> src2, Vector256<float> src3)
        {
            if (Fma.IsSupported)
            {
                return Fma.MultiplyAddNegated(src1, src2, src3);
            }
            else
            {
                Vector256<float> product = Avx.Multiply(src1, src2);
                return Avx.Subtract(src3, product);
            }
        }

        // This function implements Algorithm 1 in https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf.
        // It computes the output value of the field-aware factorization machine as the sum of the linear part and the latent part.
        // The linear part is the inner product of linearWeights and featureValues.
        // The latent part sums the feature-interaction terms, split into intra-field interactions (f == f') and inter-field interactions (f != f').
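        // Layout assumed by the indexing below: the latent vector v_j,f of the j-th feature in field f starts at
        // latentWeights + j * fieldCount * latentDim + f * latentDim, and the d-dimensional slice q_f,f' starts at
        // latentSum + f * fieldCount * latentDim + f' * latentDim.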
        public static unsafe void CalculateIntermediateVariables(int* fieldIndices, int* featureIndices, float* featureValues,
            float* linearWeights, float* latentWeights, float* latentSum, float* response, int fieldCount, int latentDim, int count)
        {
            Contracts.Assert(Avx.IsSupported);

            // The number of all possible fields.
            int m = fieldCount;
            int d = latentDim;
            int c = count;
            int* pf = fieldIndices;
            int* pi = featureIndices;
            float* px = featureValues;
            float* pw = linearWeights;
            float* pv = latentWeights;
            float* pq = latentSum;
            float linearResponse = 0;
            float latentResponse = 0;

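            // Zero-initialize the latent sum buffer q, which holds m * m * d floats (one d-dimensional
            // slice per ordered pair of fields).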
            Unsafe.InitBlock(pq, 0, (uint)(m * m * d * sizeof(float)));

            Vector256<float> y = Vector256<float>.Zero;
            Vector256<float> tmp = Vector256<float>.Zero;

            for (int i = 0; i < c; i++)
            {
                int f = pf[i];
                int j = pi[i];
                linearResponse += pw[j] * px[i];

                Vector256<float> x = Avx.BroadcastScalarToVector256(px + i);
                Vector256<float> xx = Avx.Multiply(x, x);

                // tmp -= <v_j,f, v_j,f> * x * x
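                // Subtracting <v_j,f, v_j,f> * x * x removes the feature's interaction with itself, so the
                // intra-field term <q_f,f, q_f,f> accumulated later only contributes pairs of distinct features.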
                int vBias = j * m * d + f * d;

                // j-th feature's latent vector in the f-th field hidden space.
                float* vjf = pv + vBias;

                for (int k = 0; k + 8 <= d; k += 8)
                {
                    Vector256<float> vjfBuffer = Avx.LoadVector256(vjf + k);
                    tmp = MultiplyAddNegated(Avx.Multiply(vjfBuffer, vjfBuffer), xx, tmp);
                }

                for (int fprime = 0; fprime < m; fprime++)
                {
                    vBias = j * m * d + fprime * d;
                    int qBias = f * m * d + fprime * d;
                    float* vjfprime = pv + vBias;
                    float* qffprime = pq + qBias;

                    // q_f,f' += v_j,f' * x
                    for (int k = 0; k + 8 <= d; k += 8)
                    {
                        Vector256<float> vjfprimeBuffer = Avx.LoadVector256(vjfprime + k);
                        Vector256<float> q = Avx.LoadVector256(qffprime + k);
                        q = MultiplyAdd(vjfprimeBuffer, x, q);
                        Avx.Store(qffprime + k, q);
                    }
                }
            }

            for (int f = 0; f < m; f++)
            {
                // tmp += <q_f,f, q_f,f>
                float* qff = pq + f * m * d + f * d;
                for (int k = 0; k + 8 <= d; k += 8)
                {
                    Vector256<float> qffBuffer = Avx.LoadVector256(qff + k);

                    // Intra-field interactions.
                    tmp = MultiplyAdd(qffBuffer, qffBuffer, tmp);
                }

                // y += <q_f,f', q_f',f>, f != f'
                // This loop handles inter-field interactions because f != f'.
                for (int fprime = f + 1; fprime < m; fprime++)
                {
                    float* qffprime = pq + f * m * d + fprime * d;
                    float* qfprimef = pq + fprime * m * d + f * d;
                    for (int k = 0; k + 8 <= d; k += 8)
                    {
                        // Inter-field interaction.
                        Vector256<float> qffprimeBuffer = Avx.LoadVector256(qffprime + k);
                        Vector256<float> qfprimefBuffer = Avx.LoadVector256(qfprimef + k);
                        y = MultiplyAdd(qffprimeBuffer, qfprimefBuffer, y);
                    }
                }
            }

            y = MultiplyAdd(_point5, tmp, y);
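            // Horizontally reduce the eight lanes of y to a single scalar sum: fold the upper 128-bit half
            // onto the lower half, then apply two horizontal adds.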
            tmp = Avx.Add(y, Avx.Permute2x128(y, y, 1));
            tmp = Avx.HorizontalAdd(tmp, tmp);
            y = Avx.HorizontalAdd(tmp, tmp);
            Sse.StoreScalar(&latentResponse, y.GetLower()); // The lowest slot is the response value.
            *response = linearResponse + latentResponse;
        }

        // This function implements Algorithm 2 in https://github.com/wschin/fast-ffm/blob/master/fast-ffm.pdf.
        // Calculate the stochastic gradient and update the model.
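        // Both the linear and the latent updates below follow AdaGrad: each weight keeps a running sum of its
        // squared gradients G and is adjusted as w <- w - (learningRate / sqrt(G)) * g.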
        public static unsafe void CalculateGradientAndUpdate(int* fieldIndices, int* featureIndices, float* featureValues, float* latentSum, float* linearWeights,
            float* latentWeights, float* linearAccumulatedSquaredGrads, float* latentAccumulatedSquaredGrads, float lambdaLinear, float lambdaLatent, float learningRate,
            int fieldCount, int latentDim, float weight, int count, float slope)
        {
            Contracts.Assert(Avx.IsSupported);

            int m = fieldCount;
            int d = latentDim;
            int c = count;
            int* pf = fieldIndices;
            int* pi = featureIndices;
            float* px = featureValues;
            float* pq = latentSum;
            float* pw = linearWeights;
            float* pv = latentWeights;
            float* phw = linearAccumulatedSquaredGrads;
            float* phv = latentAccumulatedSquaredGrads;

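            // Broadcast the scalar hyperparameters into 256-bit vectors so they can be applied lane-wise
            // in the vectorized latent updates below.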
            Vector256<float> wei = Vector256.Create(weight);
            Vector256<float> s = Vector256.Create(slope);
            Vector256<float> lr = Vector256.Create(learningRate);
            Vector256<float> lambdav = Vector256.Create(lambdaLatent);

            for (int i = 0; i < count; i++)
            {
                int f = pf[i];
                int j = pi[i];

                // Calculate gradient of linear term w_j.
                float g = weight * (lambdaLinear * pw[j] + slope * px[i]);

                // Accumulate the gradient of the linear term.
                phw[j] += g * g;

                // Perform ADAGRAD update rule to adjust linear term.
                pw[j] -= learningRate / MathF.Sqrt(phw[j]) * g;

                // Update latent term, v_j,f', f'=1,...,m.
                Vector256<float> x = Avx.BroadcastScalarToVector256(px + i);

                for (int fprime = 0; fprime < m; fprime++)
                {
                    float* vjfprime = pv + j * m * d + fprime * d;
                    float* hvjfprime = phv + j * m * d + fprime * d;
                    float* qfprimef = pq + fprime * m * d + f * d;
                    Vector256<float> sx = Avx.Multiply(s, x);

                    for (int k = 0; k + 8 <= d; k += 8)
                    {
                        Vector256<float> v = Avx.LoadVector256(vjfprime + k);
                        Vector256<float> q = Avx.LoadVector256(qfprimef + k);

                        // Calculate L2-norm regularization's gradient.
                        Vector256<float> gLatent = Avx.Multiply(lambdav, v);

                        Vector256<float> tmp = q;

                        // Calculate loss function's gradient.
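                        // When f' == f, q_f',f still includes this feature's own contribution v_j,f * x,
                        // so it is removed here to exclude the self-interaction from the gradient.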
                        if (fprime == f)
                            tmp = MultiplyAddNegated(v, x, q);
                        gLatent = MultiplyAdd(sx, tmp, gLatent);
                        gLatent = Avx.Multiply(wei, gLatent);

                        // Accumulate the gradient of latent vectors.
                        Vector256<float> h = MultiplyAdd(gLatent, gLatent, Avx.LoadVector256(hvjfprime + k));

                        // Perform ADAGRAD update rule to adjust latent vector.
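                        // Avx.ReciprocalSqrt gives an approximate 1/sqrt(h), traded here for speed over a
                        // full-precision square root.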
                        v = MultiplyAddNegated(lr, Avx.Multiply(Avx.ReciprocalSqrt(h), gLatent), v);
                        Avx.Store(vjfprime + k, v);
                        Avx.Store(hvjfprime + k, h);
                    }
                }
            }
        }
    }
}