forked from facebookarchive/ztorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfcomplex.lua
120 lines (97 loc) · 2.85 KB
/
fcomplex.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
--
-- Copyright (c) 2015, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the BSD-style license found in the
-- LICENSE file in the root directory of this source tree. An additional grant
-- of patent rights can be found in the PATENTS file in the same directory.
-- Rudimentary complex number support, building on top of LuaJIT's already
-- existing support (1+2i is parsed correctly; imaginary literals produce a
-- boxed `complex`, i.e. double _Complex, value, whereas the values this
-- module creates are float _Complex)
--
-- This provides basic mathematical operations and wrappers for the functions
-- in <complex.h>
local ffi = require('ffi')
local M, mt = {}, {}
-- Forward declaration: assigned by ffi.metatype at the end of the module.
-- The functions below only read it when called, i.e. after loading finished.
local complex

--- Coerce a value to a complex cdata.
-- Values that already are `float _Complex` (this module's type) or
-- `complex double` are returned unchanged; a Lua number becomes a complex
-- value with that real part and a zero imaginary part.
-- @param a number or complex cdata
-- @return complex cdata
-- @raise error for any other type
local function tocomplex(a)
  if type(a) == 'number' then
    return complex(a, 0)
  end
  local is_complex = ffi.istype(complex, a) or ffi.istype('complex double', a)
  if not is_complex then
    error('Invalid type, numeric type expected')
  end
  return a
end
M.tocomplex = tocomplex
-- Wrappers around <complex.h> functions.
-- Each name `fn` below maps to the single-precision C function `c<fn>f`
-- (e.g. 'sqrt' -> csqrtf). It is declared via ffi.cdef and exported as M[fn],
-- a Lua wrapper that passes every argument through tocomplex first, so plain
-- Lua numbers are accepted as well.

-- One complex argument, complex result.
local one_operand = {
  'acos', 'acosh', 'asin', 'asinh', 'atan', 'atanh', 'cos',
  'cosh', 'exp', 'log', 'sin', 'sinh', 'sqrt', 'tan', 'tanh',
}
-- Two complex arguments, complex result.
local two_operand = {
  'pow',
}
-- One complex argument, real (float) result. 'arg' lives only here: listing
-- it in one_operand as well would declare cargf twice with conflicting
-- return types.
local one_operand_return_real = { 'abs', 'arg' }

-- Declare `c<fn>f` for every name in `names` using the C prototype format
-- `proto` (its single %s receives the C name), and export M[fn] as a
-- wrapper taking `nargs` (1 or 2) arguments, each converted by tocomplex.
local function wrap_complex_functions(names, proto, nargs)
  for _, fn in ipairs(names) do
    local cname = 'c' .. fn .. 'f'
    ffi.cdef(string.format(proto, cname))
    if nargs == 1 then
      M[fn] = function(a)
        return ffi.C[cname](tocomplex(a))
      end
    else
      M[fn] = function(a, b)
        return ffi.C[cname](tocomplex(a), tocomplex(b))
      end
    end
  end
end

wrap_complex_functions(one_operand, 'float _Complex %s(float _Complex);', 1)
wrap_complex_functions(two_operand,
  'float _Complex %s(float _Complex, float _Complex);', 2)
wrap_complex_functions(one_operand_return_real, 'float %s(float _Complex);', 1)
--- Complex conjugate: keeps the real part and negates the imaginary part.
-- @param a number or complex cdata
-- @return float _Complex cdata
function M.conj(a)
  local z = tocomplex(a)
  local re, im = z.re, z.im
  return complex(re, -im)
end
--- `+` metamethod; either operand may be a Lua number or a complex cdata.
-- @return float _Complex cdata holding the component-wise sum
function mt.__add(a, b)
  local x = tocomplex(a)
  local y = tocomplex(b)
  local re = x.re + y.re
  local im = x.im + y.im
  return complex(re, im)
end
--- `-` (binary) metamethod; either operand may be a Lua number or a complex cdata.
-- @return float _Complex cdata holding the component-wise difference a - b
function mt.__sub(a, b)
  local x = tocomplex(a)
  local y = tocomplex(b)
  local re = x.re - y.re
  local im = x.im - y.im
  return complex(re, im)
end
--- `*` metamethod: (p + qi)(r + si) = (pr - qs) + (ps + qr)i.
-- Either operand may be a Lua number or a complex cdata.
-- @return float _Complex cdata
function mt.__mul(a, b)
  local x = tocomplex(a)
  local y = tocomplex(b)
  local re = x.re * y.re - x.im * y.im
  local im = x.re * y.im + x.im * y.re
  return complex(re, im)
end
--- `/` metamethod: multiplies the numerator by conj(b) and divides by |b|^2.
-- Either operand may be a Lua number or a complex cdata. A zero divisor is
-- not special-cased; the Lua division just yields inf/nan components.
-- @return float _Complex cdata
function mt.__div(a, b)
  local x, y = tocomplex(a), tocomplex(b)
  local denom = y.re * y.re + y.im * y.im
  local re = (x.re * y.re + x.im * y.im) / denom
  local im = (x.im * y.re - x.re * y.im) / denom
  return complex(re, im)
end
-- `^` is served directly by the cpowf wrapper exported as M.pow.
mt.__pow = M.pow

--- Unary minus metamethod: negates both components.
-- @return float _Complex cdata
function mt.__unm(a)
  local z = tocomplex(a)
  local re, im = -z.re, -z.im
  return complex(re, im)
end
-- Split an __eq operand into its (re, im) parts. A Lua number counts as a
-- complex value with a zero imaginary part; anything that is neither a
-- number nor one of the complex cdata types yields nil.
local function eq_parts(x)
  if type(x) == 'number' then
    return x, 0
  elseif ffi.istype(complex, x) or ffi.istype('complex double', x) then
    return x.re, x.im
  end
  return nil
end

--- `==` metamethod.
-- LuaJIT's FFI may call __eq when only one operand is this cdata type
-- (e.g. `z == nil`, `z == 2`, `2 == z`), and the foreign operand may be on
-- either side. An equality test must not raise, so such operands are
-- decoded or rejected instead of being indexed directly.
-- @return true iff both operands denote the same complex value
function mt.__eq(a, b)
  local are, aim = eq_parts(a)
  local bre, bim = eq_parts(b)
  if are == nil or bre == nil then
    return false
  end
  return are == bre and aim == bim
end
-- Attach the metatable to the single-precision complex ctype. The returned
-- ctype doubles as the constructor complex(re, im) used throughout this
-- module (it fills the forward-declared `complex` upvalue) and is exported
-- as M.type.
local float_complex_ct = ffi.typeof('float _Complex')
complex = ffi.metatype(float_complex_ct, mt)
M.type = complex

return M