Skip to content

Commit

Permalink
Merge pull request #10 from cesbit/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
joente committed Sep 18, 2024
2 parents 39cc9d1 + 4aecb6f commit f584a98
Show file tree
Hide file tree
Showing 16 changed files with 270 additions and 155 deletions.
19 changes: 13 additions & 6 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,27 @@ on:
jobs:
build:
runs-on: ubuntu-latest
strategy:
max-parallel: 4
matrix:
python-version: ["3.9", "3.10", "3.11", "3.12"]
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.9
uses: actions/setup-python@v4
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: 3.9
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install pytest pycodestyle
pip install pytest pycodestyle pyright
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
# - name: Run tests with pytest
# run: |
# pytest
- name: Lint with PyCodeStyle
run: |
find . -name \*.py -exec pycodestyle {} +
find . -name \*.py -exec pycodestyle {} +
- name: Type checking with PyRight
run: |
pyright
4 changes: 3 additions & 1 deletion asyncsnmplib/asn1.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ class Class(enum.IntEnum):
TNumber = Union[Number, int]
TType = Union[Type, int]
TClass = Union[Class, int]
TOid = Tuple[int, ...]
TValue = Any


class Tag(NamedTuple):
Expand Down Expand Up @@ -546,7 +548,7 @@ def _decode_null(bytes_data: bytes) -> None:
raise Error("ASN1 syntax error")

@staticmethod
def _decode_object_identifier(bytes_data: bytes) -> tuple:
def _decode_object_identifier(bytes_data: bytes) -> TOid:
result: List[int] = []
value: int = 0
for i in range(len(bytes_data)):
Expand Down
86 changes: 43 additions & 43 deletions asyncsnmplib/client.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,32 @@
import asyncio
from typing import Iterable, Optional, Tuple, List, Type
from .exceptions import (
SnmpNoConnection,
SnmpErrorNoSuchName,
SnmpTooMuchRows,
SnmpNoAuthParams,
)
from .asn1 import Tag, TOid, TValue
from .package import SnmpMessage
from .pdu import SnmpGet, SnmpGetNext, SnmpGetBulk
from .protocol import SnmpProtocol
from .v3.auth import AUTH_PROTO
from .v3.encr import PRIV_PROTO
from .v3.auth import Auth
from .v3.encr import Priv
from .v3.package import SnmpV3Message
from .v3.protocol import SnmpV3Protocol


class Snmp:
version = 1 # = v2

def __init__(self, host, port=161, community='public', max_rows=10000):
self._loop = asyncio.get_event_loop()
def __init__(
self,
host: str,
port: int = 161,
community: str = 'public',
max_rows: int = 10_000,
loop: Optional[asyncio.AbstractEventLoop] = None):
self._loop = loop if loop else asyncio.get_running_loop()
self._protocol = None
self._transport = None
self.host = host
Expand All @@ -28,7 +36,7 @@ def __init__(self, host, port=161, community='public', max_rows=10000):

# On some systems it seems to be required to set the remote_addr argument
# https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.loop.create_datagram_endpoint
async def connect(self, timeout=10):
async def connect(self, timeout: float = 10.0):
try:
infos = await self._loop.getaddrinfo(self.host, self.port)
family, *_, addr = infos[0]
Expand All @@ -44,7 +52,7 @@ async def connect(self, timeout=10):
self._transport = transport

def _get(self, oids, timeout=None):
if self._transport is None:
if self._protocol is None:
raise SnmpNoConnection
pdu = SnmpGet(0, oids)
message = SnmpMessage.make(self.version, self.community, pdu)
Expand All @@ -54,32 +62,34 @@ def _get(self, oids, timeout=None):
return self._protocol.send(message)

def _get_next(self, oids):
if self._transport is None:
if self._protocol is None:
raise SnmpNoConnection
pdu = SnmpGetNext(0, oids)
message = SnmpMessage.make(self.version, self.community, pdu)
return self._protocol.send(message)

def _get_bulk(self, oids):
if self._transport is None:
if self._protocol is None:
raise SnmpNoConnection
pdu = SnmpGetBulk(0, oids)
message = SnmpMessage.make(self.version, self.community, pdu)
return self._protocol.send(message)

async def get(self, oid, timeout=None):
async def get(self, oid: TOid, timeout: Optional[float] = None
) -> Tuple[TOid, Tag, TValue]:
vbs = await self._get([oid], timeout)
return vbs[0]

async def get_next(self, oid):
async def get_next(self, oid: TOid) -> Tuple[TOid, Tag, TValue]:
vbs = await self._get_next([oid])
return vbs[0]

async def get_next_multi(self, oids):
async def get_next_multi(self, oids: Iterable[TOid]
) -> List[Tuple[TOid, TValue]]:
vbs = await self._get_next(oids)
return [(oid, value) for oid, _, value in vbs if oid[:-1] in oids]

async def walk(self, oid):
async def walk(self, oid: TOid) -> List[Tuple[TOid, TValue]]:
next_oid = oid
prefixlen = len(oid)
rows = []
Expand Down Expand Up @@ -115,7 +125,7 @@ def close(self):
class SnmpV1(Snmp):
version = 0

async def walk(self, oid):
async def walk(self, oid: TOid) -> List[Tuple[TOid, TValue]]:
next_oid = oid
prefixlen = len(oid)
rows = []
Expand Down Expand Up @@ -150,15 +160,14 @@ class SnmpV3(Snmp):

def __init__(
self,
host,
username,
auth_proto='USM_AUTH_NONE',
auth_passwd=None,
priv_proto='USM_PRIV_NONE',
priv_passwd=None,
port=161,
max_rows=10000):
self._loop = asyncio.get_event_loop()
host: str,
username: str,
auth: Optional[Tuple[Type[Auth], str]] = None,
priv: Optional[Tuple[Type[Priv], str]] = None,
port: int = 161,
max_rows: int = 10_000,
loop: Optional[asyncio.AbstractEventLoop] = None):
self._loop = loop if loop else asyncio.get_running_loop()
self._protocol = None
self._transport = None
self.host = host
Expand All @@ -170,28 +179,16 @@ def __init__(
self._auth_hash_localized = None
self._priv_hash = None
self._priv_hash_localized = None
try:
self._auth_proto = AUTH_PROTO[auth_proto]
except KeyError:
raise Exception('Supply valid auth_proto')
try:
self._priv_proto = PRIV_PROTO[priv_proto]
except KeyError:
raise Exception('Supply valid priv_proto')
if self._priv_proto and not self._auth_proto:
raise Exception('Supply auth_proto')
if self._auth_proto:
if auth_passwd is None:
raise Exception('Supply auth_passwd')
if auth is not None:
self._auth_proto, auth_passwd = auth
self._auth_hash = self._auth_proto.hash_passphrase(auth_passwd)
if self._priv_proto:
if priv_passwd is None:
raise Exception('Supply priv_passwd')
self._priv_hash = self._auth_proto.hash_passphrase(priv_passwd)
if priv is not None:
self._priv_proto, priv_passwd = priv
self._priv_hash = self._auth_proto.hash_passphrase(priv_passwd)

# On some systems it seems to be required to set the remote_addr argument
# https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.loop.create_datagram_endpoint
async def connect(self, timeout=10):
async def connect(self, timeout: float = 10.0):
try:
infos = await self._loop.getaddrinfo(self.host, self.port)
family, *_, addr = infos[0]
Expand All @@ -211,6 +208,9 @@ async def connect(self, timeout=10):
raise SnmpNoAuthParams

async def _get_auth_params(self, timeout=10):
# TODO for long requests this will need to be refreshed
# https://datatracker.ietf.org/doc/html/rfc3414#section-2.2.3
assert self._protocol is not None
pdu = SnmpGet(0, [])
message = SnmpV3Message.make(pdu, [b'', 0, 0, b'', b'', b''])
# this request will not retry like the other requests
Expand All @@ -225,7 +225,7 @@ async def _get_auth_params(self, timeout=10):
if self._priv_proto else None

def _get(self, oids, timeout=None):
if self._transport is None:
if self._protocol is None:
raise SnmpNoConnection
elif self._auth_params is None:
raise SnmpNoAuthParams
Expand All @@ -248,7 +248,7 @@ def _get(self, oids, timeout=None):
self._priv_hash_localized)

def _get_next(self, oids):
if self._transport is None:
if self._protocol is None:
raise SnmpNoConnection
elif self._auth_params is None:
raise SnmpNoAuthParams
Expand All @@ -262,7 +262,7 @@ def _get_next(self, oids):
self._priv_hash_localized)

def _get_bulk(self, oids):
if self._transport is None:
if self._protocol is None:
raise SnmpNoConnection
elif self._auth_params is None:
raise SnmpNoAuthParams
Expand Down
1 change: 0 additions & 1 deletion asyncsnmplib/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
__all__ = (
"SnmpTimeoutError",
"SnmpUnsupportedValueType",
"SnmpErrorTooBig",
"SnmpErrorNoSuchName",
"SnmpErrorBadValue",
Expand Down
4 changes: 4 additions & 0 deletions asyncsnmplib/mib/mib.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ def on_mib(mi: dict, mibname: str, mib: dict, lk_definitions: dict):
obj['syntax'] = obj['syntax']['syntax']

lk_definitions[name] = obj
elif obj['tp'] == 'TRAP-TYPE':
lk_definitions[name] = obj

for name, obj in mib.items():
if 'value' in obj:
Expand All @@ -90,6 +92,8 @@ def on_mib(mi: dict, mibname: str, mib: dict, lk_definitions: dict):
names[name] = obj
elif obj['tp'] == 'OBJECT-GROUP':
names[name] = obj
elif obj['tp'] == 'NOTIFICATION-TYPE':
names[name] = obj

for name, obj in names.items():
other_name = name
Expand Down
21 changes: 14 additions & 7 deletions asyncsnmplib/mib/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Tuple, Union
from typing import Tuple, Union, List
from ..asn1 import TOid, TValue
from .mib_index import MIB_INDEX
from .syntax_funs import SYNTAX_FUNS

Expand All @@ -8,7 +9,7 @@
FLAGS_SEPERATOR = ','


def on_octet_string(value: bytes) -> str:
def on_octet_string(value: TValue) -> Union[str, None]:
"""
used as a fallback for OCTET STRING when no formatter is found/defined
"""
Expand All @@ -18,13 +19,13 @@ def on_octet_string(value: bytes) -> str:
return


def on_integer(value: int) -> str:
def on_integer(value: TValue) -> Union[int, None]:
if not isinstance(value, int):
return
return value


def on_oid_map(oid: Tuple[int]) -> str:
def on_oid_map(oid: TValue) -> Union[str, None]:
if not isinstance(oid, tuple):
# some devices don't follow mib's syntax
# for example ipAddressTable.ipAddressPrefix returns an int in case of
Expand All @@ -45,7 +46,7 @@ def on_value_map_b(value: bytes, map_: dict) -> str:
v for k, v in map_.items() if value[k // 8] & (1 << k % 8))


def on_syntax(syntax: dict, value: Union[int, str, bytes]):
def on_syntax(syntax: dict, value: TValue):
"""
this is point where bytes are converted to right datatype
"""
Expand All @@ -65,7 +66,10 @@ def on_syntax(syntax: dict, value: Union[int, str, bytes]):
raise Exception(f'Invalid syntax {syntax}')


def on_result(base_oid: Tuple[int], result: dict) -> Tuple[str, list]:
def on_result(
base_oid: TOid,
result: List[Tuple[TOid, TValue]],
) -> Tuple[str, List[dict]]:
"""returns a more compat result (w/o prefixes) and groups formatted
metrics by base_oid
"""
Expand Down Expand Up @@ -109,7 +113,10 @@ def on_result(base_oid: Tuple[int], result: dict) -> Tuple[str, list]:
return result_name, list(table.values())


def on_result_base(base_oid: Tuple[int], result: dict) -> Tuple[str, list]:
def on_result_base(
base_oid: TOid,
result: List[Tuple[TOid, TValue]],
) -> Tuple[str, List[dict]]:
"""returns formatted metrics grouped by base_oid
"""
base = MIB_INDEX[base_oid]
Expand Down
12 changes: 7 additions & 5 deletions asyncsnmplib/package.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .asn1 import Decoder, Encoder, Number
from typing import Optional, Tuple, List
from .asn1 import Decoder, Encoder, Number, Tag, TOid, TValue


class Package:
Expand All @@ -8,12 +9,13 @@ class Package:
pdu = None

def __init__(self):
self.request_id = None
self.error_status = None
self.error_index = None
self.variable_bindings = []
self.request_id: Optional[int] = None
self.error_status: Optional[int] = None
self.error_index: Optional[int] = None
self.variable_bindings: List[Tuple[TOid, Tag, TValue]] = []

def encode(self):
assert self.pdu is not None
encoder = Encoder()

with encoder.enter(Number.Sequence):
Expand Down
Loading

0 comments on commit f584a98

Please sign in to comment.