21
21
from .TftpPacketTypes import *
22
22
from .TftpShared import *
23
23
24
+ if TYPE_CHECKING :
25
+ from .TftpContexts import TftpContext
26
+
27
+
24
28
log = logging .getLogger ("partftpy.TftpStates" )
25
29
26
30
###############################################################################
@@ -35,7 +39,7 @@ def __init__(self, context):
35
39
"""Constructor for setting up common instance variables. The involved
36
40
file object is required, since in tftp there's always a file
37
41
involved."""
38
- self .context = context
42
+ self .context = context # type: TftpContext
39
43
40
44
def handle (self , pkt , raddress , rport ):
41
45
"""An abstract method for handling a packet. It is expected to return
@@ -50,6 +54,9 @@ def handleOACK(self, pkt):
50
54
log .info ("Successful negotiation of options" )
51
55
# Set options to OACK options
52
56
self .context .options = pkt .options
57
+ tsize = pkt .options .get ("tsize" )
58
+ if tsize :
59
+ self .context .metrics .tsize = tsize
53
60
for k , v in self .context .options .items ():
54
61
log .info (" %s = %s" , k , v )
55
62
else :
@@ -112,6 +119,7 @@ def sendDAT(self):
112
119
dat .data = buffer
113
120
dat .blocknumber = blocknumber
114
121
self .context .metrics .bytes += len (dat .data )
122
+ self .context .metrics .packets += 1
115
123
# Testing hook
116
124
if NETWORK_UNRELIABILITY > 0 and random .randrange (NETWORK_UNRELIABILITY ) == 0 :
117
125
log .warning ("Skipping DAT packet %d for testing" , dat .blocknumber )
@@ -122,7 +130,7 @@ def sendDAT(self):
122
130
)
123
131
self .context .metrics .last_dat_time = time .time ()
124
132
if self .context .packethook :
125
- self .context .packethook (dat )
133
+ self .context .packethook (dat , self . context )
126
134
self .context .last_pkt = dat
127
135
return finished
128
136
@@ -175,6 +183,7 @@ def resendLast(self):
175
183
assert self .context .last_pkt is not None
176
184
log .warning ("Resending packet %s on sessions %s" , self .context .last_pkt , self )
177
185
self .context .metrics .resent_bytes += len (self .context .last_pkt .buffer )
186
+ self .context .metrics .resent_packets += 1
178
187
self .context .metrics .add_dup (self .context .last_pkt )
179
188
sendto_port = self .context .tidport
180
189
if not sendto_port :
@@ -186,9 +195,10 @@ def resendLast(self):
186
195
self .context .last_pkt .encode ().buffer , (self .context .host , sendto_port )
187
196
)
188
197
if self .context .packethook :
189
- self .context .packethook (self .context .last_pkt )
198
+ self .context .packethook (self .context .last_pkt , self . context )
190
199
191
200
def handleDat (self , pkt ):
201
+ # type: (TftpPacket) -> TftpState
192
202
"""This method handles a DAT packet during a client download, or a
193
203
server upload."""
194
204
log .debug ("Handling DAT packet - block %d" , pkt .blocknumber )
@@ -202,6 +212,7 @@ def handleDat(self, pkt):
202
212
log .debug ("Writing %d bytes to output file" , len (pkt .data ))
203
213
self .context .fileobj .write (pkt .data )
204
214
self .context .metrics .bytes += len (pkt .data )
215
+ self .context .metrics .packets += 1
205
216
# Check for end-of-file, any less than full data packet.
206
217
if len (pkt .data ) < self .context .options ["blksize" ]:
207
218
log .info ("End of file detected" )
@@ -354,6 +365,7 @@ def handle(self, pkt, raddress, rport):
354
365
tsize = str (self .context .fileobj .tell ())
355
366
self .context .fileobj .seek (0 , 0 )
356
367
self .context .options ["tsize" ] = tsize
368
+ self .context .metrics .tsize = tsize
357
369
358
370
if sendoack :
359
371
# Note, next_block is 0 here since that's the proper
0 commit comments