@@ -61,11 +61,7 @@ def save_token_with_org_info(self, token: str) -> None:
61
61
# Store ALL organizations in cache for local resolution
62
62
all_orgs = [{"id" : org .get ("id" ), "name" : org .get ("name" )} for org in orgs ]
63
63
primary_org = orgs [0 ] # Use first org as primary/default
64
- auth_data ["organization" ] = {
65
- "id" : primary_org .get ("id" ),
66
- "name" : primary_org .get ("name" ),
67
- "all_orgs" : all_orgs
68
- }
64
+ auth_data ["organization" ] = {"id" : primary_org .get ("id" ), "name" : primary_org .get ("name" ), "all_orgs" : all_orgs }
69
65
auth_data ["organizations_cache" ] = all_orgs # Separate cache for easy access
70
66
71
67
except requests .RequestException as e :
@@ -180,7 +176,7 @@ def get_user_info(self) -> dict | None:
180
176
181
177
def get_cached_organizations (self ) -> list [dict ] | None :
182
178
"""Get all cached organizations.
183
-
179
+
184
180
Returns:
185
181
List of organization dictionaries with 'id' and 'name' keys, or None if no cache.
186
182
"""
@@ -194,37 +190,75 @@ def get_cached_organizations(self) -> list[dict] | None:
194
190
195
191
def is_org_id_in_cache (self , org_id : int ) -> bool :
196
192
"""Check if an organization ID exists in the local cache.
197
-
193
+
198
194
Args:
199
195
org_id: The organization ID to check
200
-
196
+
201
197
Returns:
202
198
True if the organization ID is found in cache, False otherwise.
203
199
"""
204
200
cached_orgs = self .get_cached_organizations ()
205
201
if not cached_orgs :
206
202
return False
207
-
203
+
208
204
return any (org .get ("id" ) == org_id for org in cached_orgs )
209
205
210
206
def get_org_name_from_cache (self , org_id : int ) -> str | None :
211
207
"""Get organization name from cache by ID.
212
-
208
+
213
209
Args:
214
210
org_id: The organization ID to look up
215
-
211
+
216
212
Returns:
217
213
Organization name if found in cache, None otherwise.
218
214
"""
219
215
cached_orgs = self .get_cached_organizations ()
220
216
if not cached_orgs :
221
217
return None
222
-
218
+
223
219
for org in cached_orgs :
224
220
if org .get ("id" ) == org_id :
225
221
return org .get ("name" )
226
222
return None
227
223
224
+ def set_default_organization (self , org_id : int , org_name : str ) -> None :
225
+ """Set the default organization in auth.json.
226
+
227
+ Args:
228
+ org_id: The organization ID to set as default
229
+ org_name: The organization name
230
+ """
231
+ auth_data = self .get_auth_data ()
232
+ if not auth_data :
233
+ msg = "No authentication data found. Please run 'codegen login' first."
234
+ raise ValueError (msg )
235
+
236
+ # Verify the org exists in cache
237
+ if not self .is_org_id_in_cache (org_id ):
238
+ msg = f"Organization { org_id } not found in cache. Please run 'codegen login' to refresh."
239
+ raise ValueError (msg )
240
+
241
+ # Update the organization info
242
+ auth_data ["organization" ] = {"id" : org_id , "name" : org_name , "all_orgs" : auth_data .get ("organization" , {}).get ("all_orgs" , [])}
243
+
244
+ # Save to file
245
+ try :
246
+ import json
247
+
248
+ with open (self .token_file , "w" ) as f :
249
+ json .dump (auth_data , f , indent = 2 )
250
+
251
+ # Secure the file permissions (read/write for owner only)
252
+ os .chmod (self .token_file , 0o600 )
253
+
254
+ # Invalidate cache
255
+ global _token_cache , _cache_mtime
256
+ _token_cache = None
257
+ _cache_mtime = None
258
+ except Exception as e :
259
+ msg = f"Error saving default organization: { e } "
260
+ raise ValueError (msg )
261
+
228
262
229
263
def get_current_token () -> str | None :
230
264
"""Get the current authentication token if one exists.
@@ -289,7 +323,7 @@ def get_current_org_name() -> str | None:
289
323
290
324
def get_cached_organizations () -> list [dict ] | None :
291
325
"""Get all cached organizations.
292
-
326
+
293
327
Returns:
294
328
List of organization dictionaries with 'id' and 'name' keys, or None if no cache.
295
329
"""
@@ -299,10 +333,10 @@ def get_cached_organizations() -> list[dict] | None:
299
333
300
334
def is_org_id_cached (org_id : int ) -> bool :
301
335
"""Check if an organization ID exists in the local cache.
302
-
336
+
303
337
Args:
304
338
org_id: The organization ID to check
305
-
339
+
306
340
Returns:
307
341
True if the organization ID is found in cache, False otherwise.
308
342
"""
@@ -312,10 +346,10 @@ def is_org_id_cached(org_id: int) -> bool:
312
346
313
347
def get_org_name_from_cache (org_id : int ) -> str | None :
314
348
"""Get organization name from cache by ID.
315
-
349
+
316
350
Args:
317
351
org_id: The organization ID to look up
318
-
352
+
319
353
Returns:
320
354
Organization name if found in cache, None otherwise.
321
355
"""
@@ -335,9 +369,10 @@ def get_current_user_info() -> dict | None:
335
369
336
370
# Repository caching functions (similar to organization caching)
337
371
372
+
338
373
def get_cached_repositories () -> list [dict ] | None :
339
374
"""Get all cached repositories.
340
-
375
+
341
376
Returns:
342
377
List of repository dictionaries with 'id' and 'name' keys, or None if no cache.
343
378
"""
@@ -350,7 +385,7 @@ def get_cached_repositories() -> list[dict] | None:
350
385
351
386
def cache_repositories (repositories : list [dict ]) -> None :
352
387
"""Cache repositories to local storage.
353
-
388
+
354
389
Args:
355
390
repositories: List of repository dictionaries to cache
356
391
"""
@@ -361,53 +396,65 @@ def cache_repositories(repositories: list[dict]) -> None:
361
396
# Save back to file
362
397
try :
363
398
import json
364
- with open (token_manager .token_file , 'w' ) as f :
399
+
400
+ with open (token_manager .token_file , "w" ) as f :
365
401
json .dump (auth_data , f , indent = 2 )
366
402
except Exception :
367
403
pass # Fail silently
368
404
369
405
370
406
def is_repo_id_cached (repo_id : int ) -> bool :
371
407
"""Check if a repository ID exists in the local cache.
372
-
408
+
373
409
Args:
374
410
repo_id: The repository ID to check
375
-
411
+
376
412
Returns:
377
413
True if the repository ID is found in cache, False otherwise.
378
414
"""
379
415
cached_repos = get_cached_repositories ()
380
416
if not cached_repos :
381
417
return False
382
-
418
+
383
419
return any (repo .get ("id" ) == repo_id for repo in cached_repos )
384
420
385
421
386
422
def get_repo_name_from_cache (repo_id : int ) -> str | None :
387
423
"""Get repository name from cache by ID.
388
-
424
+
389
425
Args:
390
426
repo_id: The repository ID to look up
391
-
427
+
392
428
Returns:
393
429
Repository name if found in cache, None otherwise.
394
430
"""
395
431
cached_repos = get_cached_repositories ()
396
432
if not cached_repos :
397
433
return None
398
-
434
+
399
435
for repo in cached_repos :
400
436
if repo .get ("id" ) == repo_id :
401
437
return repo .get ("name" )
402
-
438
+
403
439
return None
404
440
405
441
406
442
def get_current_repo_name () -> str | None :
407
443
"""Get the current repository name from environment or cache."""
408
444
from codegen .cli .utils .repo import get_current_repo_id
409
-
445
+
410
446
repo_id = get_current_repo_id ()
411
447
if repo_id :
412
448
return get_repo_name_from_cache (repo_id )
413
449
return None
450
+
451
+
452
+ def set_default_organization (org_id : int , org_name : str ) -> None :
453
+ """Set the default organization in auth.json.
454
+
455
+ Args:
456
+ org_id: The organization ID to set as default
457
+ org_name: The organization name
458
+ """
459
+ token_manager = TokenManager ()
460
+ return token_manager .set_default_organization (org_id , org_name )
0 commit comments