Skip to content

Commit 81f4209

Browse files
benshukbenshuk
andauthored
fix: ♻️ expose poll_for_status function (#291)
* fix: ♻️ expose `poll_for_status` function * refactor: ♻️ update `poll_for_status` method parameters to use consistent naming convention * chore: 🔧 update CI configuration to use ubuntu-latest --------- Co-authored-by: benshuk <bens@ai21.com>
1 parent 624fa7c commit 81f4209

File tree

3 files changed

+13
-11
lines changed

3 files changed

+13
-11
lines changed

.github/workflows/semantic-pr.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ on:
1313

1414
jobs:
1515
semantic-pr:
16-
runs-on: ubuntu-20.04
16+
runs-on: ubuntu-latest
1717
timeout-minutes: 1
1818
steps:
1919
- name: Semantic pull-request

ai21/clients/common/maestro/run.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def retrieve(self, run_id: str) -> RunResponse:
6666
pass
6767

6868
@abstractmethod
69-
def _poll_for_status(self, *, run_id: str, poll_interval: float, poll_timeout: float) -> RunResponse:
69+
def poll_for_status(self, *, run_id: str, poll_interval_sec: float, poll_timeout_sec: float) -> RunResponse:
7070
pass
7171

7272
@abstractmethod

ai21/clients/studio/resources/maestro/run.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def retrieve(
5353
) -> RunResponse:
5454
return self._get(path=f"/{self._module_name}/{run_id}", response_cls=RunResponse)
5555

56-
def _poll_for_status(self, *, run_id: str, poll_interval: float, poll_timeout: float) -> RunResponse:
56+
def poll_for_status(self, *, run_id: str, poll_interval_sec: float, poll_timeout_sec: float) -> RunResponse:
5757
start_time = time.time()
5858

5959
while True:
@@ -62,10 +62,10 @@ def _poll_for_status(self, *, run_id: str, poll_interval: float, poll_timeout: f
6262
if run.status in TERMINATED_RUN_STATUSES:
6363
return run
6464

65-
if (time.time() - start_time) >= poll_timeout:
65+
if (time.time() - start_time) >= poll_timeout_sec:
6666
return run
6767

68-
time.sleep(poll_interval)
68+
time.sleep(poll_interval_sec)
6969

7070
def create_and_poll(
7171
self,
@@ -92,7 +92,9 @@ def create_and_poll(
9292
**kwargs,
9393
)
9494

95-
return self._poll_for_status(run_id=run.id, poll_interval=poll_interval_sec, poll_timeout=poll_timeout_sec)
95+
return self.poll_for_status(
96+
run_id=run.id, poll_interval_sec=poll_interval_sec, poll_timeout_sec=poll_timeout_sec
97+
)
9698

9799

98100
class AsyncMaestroRun(AsyncStudioResource, BaseMaestroRun):
@@ -127,7 +129,7 @@ async def retrieve(
127129
) -> RunResponse:
128130
return await self._get(path=f"/{self._module_name}/{run_id}", response_cls=RunResponse)
129131

130-
async def _poll_for_status(self, *, run_id: str, poll_interval: float, poll_timeout: float) -> RunResponse:
132+
async def poll_for_status(self, *, run_id: str, poll_interval_sec: float, poll_timeout_sec: float) -> RunResponse:
131133
start_time = time.time()
132134

133135
while True:
@@ -136,10 +138,10 @@ async def _poll_for_status(self, *, run_id: str, poll_interval: float, poll_time
136138
if run.status in TERMINATED_RUN_STATUSES:
137139
return run
138140

139-
if (time.time() - start_time) >= poll_timeout:
141+
if (time.time() - start_time) >= poll_timeout_sec:
140142
return run
141143

142-
await asyncio.sleep(poll_interval)
144+
await asyncio.sleep(poll_interval_sec)
143145

144146
async def create_and_poll(
145147
self,
@@ -166,6 +168,6 @@ async def create_and_poll(
166168
**kwargs,
167169
)
168170

169-
return await self._poll_for_status(
170-
run_id=run.id, poll_interval=poll_interval_sec, poll_timeout=poll_timeout_sec
171+
return await self.poll_for_status(
172+
run_id=run.id, poll_interval_sec=poll_interval_sec, poll_timeout_sec=poll_timeout_sec
171173
)

0 commit comments

Comments
 (0)