Skip to content

Commit 117f4b1

Browse files
committed
Merge pull request #20 from mogproject/topic-getch-#15
implement TerminalHandler closes #15
2 parents 9748d59 + be6b1f1 commit 117f4b1

File tree

5 files changed

+303
-1
lines changed

5 files changed

+303
-1
lines changed

src/mog_commons/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '0.1.14'
1+
__version__ = '0.1.15'

src/mog_commons/terminal.py

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
from __future__ import division, print_function, absolute_import, unicode_literals
2+
3+
import os
4+
import sys
5+
import codecs
6+
import subprocess
7+
import locale
8+
import platform
9+
import time
10+
11+
if os.name == 'nt':
12+
# for Windows
13+
import msvcrt
14+
else:
15+
# for Unix/Linux/Mac/CygWin
16+
import termios
17+
import tty
18+
19+
from mog_commons.case_class import CaseClass
20+
from mog_commons.string import to_unicode
21+
22+
__all__ = [
23+
'TerminalHandler',
24+
]
25+
26+
DEFAULT_GETCH_REPEAT_THRESHOLD = 0.3 # in seconds
27+
28+
29+
class TerminalHandler(CaseClass):
30+
"""
31+
IMPORTANT: When you use this class in POSIX environment, make sure to set signal function for restoring terminal
32+
attributes. The function `restore_terminal` is for that purpose. See the example below.
33+
34+
:example:
35+
import signal
36+
37+
t = TerminalHandler()
38+
signal.signal(signal.SIGTERM, t.restore_terminal)
39+
40+
try:
41+
(do your work)
42+
finally:
43+
t.restore_terminal(None, None)
44+
"""
45+
46+
def __init__(self, term_type=None, encoding=None,
47+
stdin=sys.stdin, stdout=sys.stdout, stderr=sys.stderr,
48+
getch_repeat_threshold=DEFAULT_GETCH_REPEAT_THRESHOLD):
49+
CaseClass.__init__(self,
50+
('term_type', term_type or self._detect_term_type()),
51+
('encoding', encoding or self._detect_encoding(stdout)),
52+
('stdin', stdin),
53+
('stdout', stdout),
54+
('stderr', stderr),
55+
('getch_repeat_threshold', getch_repeat_threshold)
56+
)
57+
self.restore_terminal = self._get_restore_function() # binary function for restoring terminal attributes
58+
self.last_getch_time = 0.0
59+
self.last_getch_char = '..'
60+
61+
@staticmethod
62+
def _detect_term_type():
63+
"""
64+
Detect the type of the terminal.
65+
"""
66+
if os.name == 'nt':
67+
if os.environ.get('TERM') == 'xterm':
68+
# maybe MinTTY
69+
return 'mintty'
70+
else:
71+
return 'nt'
72+
if platform.system().upper().startswith('CYGWIN'):
73+
return 'cygwin'
74+
return 'posix'
75+
76+
@staticmethod
77+
def _detect_encoding(stdout):
78+
"""
79+
Detect the default encoding for the terminal's output.
80+
:return: string: encoding string
81+
"""
82+
if stdout.encoding:
83+
return stdout.encoding
84+
85+
if os.environ.get('LANG'):
86+
encoding = os.environ.get('LANG').split('.')[-1]
87+
88+
# validate the encoding string
89+
ret = None
90+
try:
91+
ret = codecs.lookup(encoding)
92+
except LookupError:
93+
pass
94+
if ret:
95+
return encoding
96+
97+
return locale.getpreferredencoding()
98+
99+
def _get_restore_function(self):
100+
"""
101+
Return the binary function for restoring terminal attributes.
102+
:return: function (signal, frame) => None:
103+
"""
104+
if os.name == 'nt':
105+
return lambda signal, frame: None
106+
107+
assert hasattr(self.stdin, 'fileno'), 'Invalid input device.'
108+
fd = self.stdin.fileno()
109+
110+
try:
111+
initial = termios.tcgetattr(fd)
112+
except termios.error:
113+
return lambda signal, frame: None
114+
115+
return lambda signal, frame: termios.tcsetattr(fd, termios.TCSADRAIN, initial)
116+
117+
def clear(self):
118+
"""
119+
Clear the terminal screen.
120+
"""
121+
if self.stdout.isatty() or self.term_type == 'mintty':
122+
cmd, shell = {
123+
'posix': ('clear', False),
124+
'nt': ('cls', True),
125+
'cygwin': (['echo', '-en', r'\ec'], False),
126+
'mintty': (r'echo -en "\ec', False),
127+
}[self.term_type]
128+
subprocess.call(cmd, shell=shell, stdin=self.stdin, stdout=self.stdout, stderr=self.stderr)
129+
130+
def clear_input_buffer(self):
131+
"""
132+
Clear the input buffer.
133+
"""
134+
if self.stdin.isatty():
135+
if os.name == 'nt':
136+
while msvcrt.kbhit():
137+
msvcrt.getch()
138+
else:
139+
try:
140+
self.stdin.seek(0, 2) # may fail in some unseekable file object
141+
except IOError:
142+
pass
143+
144+
def getch(self):
145+
"""
146+
Read one character from stdin.
147+
148+
If stdin is not a tty, read input as one line.
149+
:return: unicode:
150+
"""
151+
ch = self._get_one_char()
152+
self.clear_input_buffer()
153+
154+
try:
155+
# accept only unicode characters (for Python 2)
156+
uch = to_unicode(ch, 'ascii')
157+
except UnicodeError:
158+
return ''
159+
160+
return uch if self._check_key_repeat(uch) else ''
161+
162+
def _get_one_char(self):
163+
if not self.stdin.isatty(): # pipeline or MinTTY
164+
return self.gets()[:1]
165+
elif os.name == 'nt': # Windows
166+
return msvcrt.getch()
167+
else: # POSIX
168+
try:
169+
tty.setraw(self.stdin.fileno())
170+
return self.stdin.read(1)
171+
finally:
172+
self.restore_terminal(None, None)
173+
174+
def _check_key_repeat(self, ch):
175+
if self.getch_repeat_threshold <= 0.0:
176+
return True
177+
178+
t = time.time()
179+
if ch == self.last_getch_char and t < self.last_getch_time + self.getch_repeat_threshold:
180+
return False
181+
182+
self.last_getch_time = t
183+
self.last_getch_char = ch
184+
return True
185+
186+
def gets(self):
187+
"""
188+
Read line from stdin.
189+
190+
The trailing newline will be omitted.
191+
:return: string:
192+
"""
193+
ret = self.stdin.readline()
194+
if ret == '':
195+
raise EOFError # To break out of EOF loop
196+
return ret.rstrip('\n')

src/mog_commons/unittest.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,12 @@
1515

1616
from mog_commons.string import to_bytes, to_str
1717

18+
__all__ = [
19+
'FakeInput',
20+
'FakeBytesInput',
21+
'TestCase',
22+
]
23+
1824

1925
class StringBuffer(object):
2026
"""
@@ -39,6 +45,32 @@ def getvalue(self, encoding='utf-8', errors='strict'):
3945
return self._buffer.decode(encoding, errors)
4046

4147

48+
class FakeInput(six.StringIO):
49+
"""Fake input object"""
50+
51+
def __init__(self, buff=None):
52+
six.StringIO.__init__(self, buff or '')
53+
54+
def fileno(self):
55+
return 0
56+
57+
def isatty(self):
58+
return True
59+
60+
61+
class FakeBytesInput(six.BytesIO):
62+
"""Fake bytes input object"""
63+
64+
def __init__(self, buff=None):
65+
six.BytesIO.__init__(self, buff or b'')
66+
67+
def fileno(self):
68+
return 0
69+
70+
def isatty(self):
71+
return True
72+
73+
4274
class TestCase(base_unittest.TestCase):
4375
def assertRaisesRegexp(self, expected_exception, expected_regexp, callable_obj=None, *args, **kwargs):
4476
"""

tests/mog_commons/test_terminal.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# -*- coding: utf-8 -*-
2+
from __future__ import division, print_function, absolute_import, unicode_literals
3+
4+
import os
5+
import time
6+
from mog_commons.terminal import TerminalHandler
7+
from mog_commons.unittest import TestCase, base_unittest, FakeBytesInput
8+
9+
10+
class TestTerminal(TestCase):
11+
def test_getch_from_file(self):
12+
with open(os.path.join('tests', 'resources', 'test_terminal_input.txt')) as f:
13+
t = TerminalHandler(stdin=f)
14+
self.assertEqual(t.getch(), 'a')
15+
self.assertRaises(EOFError, t.getch)
16+
17+
@base_unittest.skipUnless(os.name != 'nt', 'requires POSIX compatible')
18+
def test_getch(self):
19+
self.assertEqual(TerminalHandler(stdin=FakeBytesInput(b'')).getch(), '')
20+
self.assertEqual(TerminalHandler(stdin=FakeBytesInput(b'\x03')).getch(), '\x03')
21+
self.assertEqual(TerminalHandler(stdin=FakeBytesInput(b'abc')).getch(), 'a')
22+
self.assertEqual(TerminalHandler(stdin=FakeBytesInput('あ'.encode('utf-8'))).getch(), '')
23+
self.assertEqual(TerminalHandler(stdin=FakeBytesInput('あ'.encode('sjis'))).getch(), '')
24+
25+
@base_unittest.skipUnless(os.name != 'nt', 'requires POSIX compatible')
26+
def test_getch_key_repeat(self):
27+
fin = FakeBytesInput(b'abcde')
28+
29+
def append_char(ch):
30+
fin.write(ch)
31+
fin.seek(-len(ch), 1)
32+
33+
t1 = TerminalHandler(stdin=fin)
34+
self.assertEqual(t1.getch(), 'a')
35+
append_char(b'x')
36+
self.assertEqual(t1.getch(), 'x')
37+
append_char(b'x')
38+
self.assertEqual(t1.getch(), '')
39+
append_char(b'x')
40+
self.assertEqual(t1.getch(), '')
41+
append_char(b'y')
42+
self.assertEqual(t1.getch(), 'y')
43+
append_char(b'y')
44+
self.assertEqual(t1.getch(), '')
45+
46+
time.sleep(1)
47+
append_char(b'y')
48+
self.assertEqual(t1.getch(), 'y')
49+
50+
@base_unittest.skipUnless(os.name != 'nt', 'requires POSIX compatible')
51+
def test_getch_key_repeat_disabled(self):
52+
fin = FakeBytesInput(b'abcde')
53+
54+
def append_char(ch):
55+
fin.write(ch)
56+
fin.seek(-len(ch), 1)
57+
58+
t1 = TerminalHandler(stdin=fin, getch_repeat_threshold=0)
59+
self.assertEqual(t1.getch(), 'a')
60+
append_char(b'x')
61+
self.assertEqual(t1.getch(), 'x')
62+
append_char(b'x')
63+
self.assertEqual(t1.getch(), 'x')
64+
append_char(b'x')
65+
self.assertEqual(t1.getch(), 'x')
66+
append_char(b'y')
67+
self.assertEqual(t1.getch(), 'y')
68+
append_char(b'y')
69+
self.assertEqual(t1.getch(), 'y')
70+
71+
time.sleep(1)
72+
append_char(b'y')
73+
self.assertEqual(t1.getch(), 'y')
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
abcde

0 commit comments

Comments
 (0)