@@ -467,6 +467,62 @@ local function getsplotvars(t)
467467 return legend ,x ,y ,z
468468end
469469
470+ local function getscatter3vars (t )
471+ local legend = nil
472+ local x = nil
473+ local y = nil
474+ local z = nil
475+
476+ local function istensor (v )
477+ return type (v ) == ' userdata' and torch .typename (v ):sub (- 6 ) == ' Tensor'
478+ end
479+
480+ local function isstring (v )
481+ return type (v ) == ' string'
482+ end
483+
484+ if # t ~= 3 and # t ~= 4 then
485+ error (' expecting [string,] tensor, tensor, tensor' )
486+ end
487+
488+ if isstring (t [1 ]) then
489+ if # t ~= 4 then
490+ error (' expecting [string,] tensor, tensor, tensor' )
491+ end
492+ for i = 2 , 4 do
493+ if not istensor (t [i ]) then
494+ error (' expecting [string,] tensor, tensor, tensor' )
495+ end
496+ end
497+ legend = t [1 ]
498+ x = t [2 ]
499+ y = t [3 ]
500+ z = t [4 ]
501+ elseif istensor (t [1 ]) then
502+ if # t ~= 3 then
503+ error (' expecting [string,] tensor, tensor, tensor' )
504+ end
505+ for i = 2 , 3 do
506+ if not istensor (t [i ]) then
507+ error (' expecting [string,] tensor, tensor, tensor' )
508+ end
509+ end
510+ x = t [1 ]
511+ y = t [2 ]
512+ z = t [3 ]
513+ legend = ' '
514+ else
515+ error (' expecting [string,] tensor, tensor, tensor' )
516+ end
517+
518+ assert (x :dim () == 1 and y :dim () == 1 and z :dim () == 1 ,
519+ ' x, y and z must be 1D' )
520+ assert (x :isSameSizeAs (y ) and x :isSameSizeAs (z ),
521+ ' x, y and z must be the same size' )
522+
523+ return legend , x , y , z
524+ end
525+
470526local function getimagescvars (t )
471527 local palette = nil
472528 local x = nil
@@ -595,6 +651,35 @@ local function gnu_splot_string(legend,x,y,z)
595651 return hstr ,table.concat (dstr )
596652end
597653
654+ local function gnu_scatter3_string (legend , x , y , z )
655+ local hstr = string.format (' %s\n ' ,' set contour base' )
656+ hstr = string.format (' %s%s\n ' ,hstr ,' set style data points\n ' )
657+ hstr = string.format (' %s%s\n ' ,hstr ,' set hidden3d\n ' )
658+
659+ hstr = hstr .. ' splot '
660+ local dstr = {' ' }
661+ local coef
662+ for i = 1 , # legend do
663+ if i > 1 then hstr = hstr .. ' , ' end
664+ hstr = hstr .. " '-'title '" .. legend [i ] .. " ' " .. ' with points'
665+ end
666+ hstr = hstr .. ' \n '
667+ for i = 1 , # legend do
668+ local xi = x [i ]
669+ local yi = y [i ]
670+ local zi = z [i ]
671+ for j = 1 , xi :size (1 ) do
672+ local xij = xi [j ]
673+ local yij = yi [j ]
674+ local zij = zi [j ]
675+ table.insert (dstr ,
676+ string.format (' %g %g %g\n ' , xij , yij , zij ))
677+ end
678+ table.insert (dstr , ' e\n ' )
679+ end
680+ return hstr , table.concat (dstr )
681+ end
682+
598683local function gnu_imagesc_string (x ,palette )
599684 local hstr = string.format (' %s\n ' ,' set view map' )
600685 hstr = string.format (' %s%s %s\n ' ,hstr ,' set palette' ,palette )
@@ -753,6 +838,11 @@ local function gnusplot(legend,x,y,z)
753838 writeToCurrent (hdr )
754839 writeToCurrent (data )
755840end
841+ local function gnuscatter3 (legend , x , y , z )
842+ local hdr , data = gnu_scatter3_string (legend , x , y , z )
843+ writeToCurrent (hdr )
844+ writeToCurrent (data )
845+ end
756846local function gnuimagesc (x ,palette )
757847 local hdr ,data = gnu_imagesc_string (x ,palette )
758848 writeToCurrent (hdr )
@@ -918,6 +1008,41 @@ function gnuplot.splot(...)
9181008 gnusplot (legends ,xdata ,ydata ,zdata )
9191009end
9201010
1011+ -- scatter3(x, y, z)
1012+ -- scatter3({x1, y1, z1}, {x2, y2, z2})
1013+ -- scatter3({'name1', x1, y1, z1}, {'name2', x2, y2, z2})
1014+ function gnuplot .scatter3 (...)
1015+ local arg = {... }
1016+ if select (' #' , ... ) == 0 then
1017+ error (' no inputs, expecting at least a matrix' )
1018+ end
1019+
1020+ local xdata = {}
1021+ local ydata = {}
1022+ local zdata = {}
1023+ local legends = {}
1024+
1025+ if type (arg [1 ]) == " table" then
1026+ if type (arg [1 ][1 ]) == " table" then
1027+ arg = arg [1 ]
1028+ end
1029+ for i ,v in ipairs (arg ) do
1030+ local l , x , y , z = getscatter3vars (v )
1031+ legends [# legends + 1 ] = l
1032+ xdata [# xdata + 1 ] = x
1033+ ydata [# ydata + 1 ] = y
1034+ zdata [# zdata + 1 ] = z
1035+ end
1036+ else
1037+ local l , x , y , z = getscatter3vars (arg )
1038+ legends [# legends + 1 ] = l
1039+ xdata [# xdata + 1 ] = x
1040+ ydata [# ydata + 1 ] = y
1041+ zdata [# zdata + 1 ] = z
1042+ end
1043+ gnuscatter3 (legends , xdata , ydata , zdata )
1044+ end
1045+
9211046-- imagesc(x) -- x 2D tensor [0 .. 1]
9221047function gnuplot .imagesc (...)
9231048 local arg = {... }
0 commit comments