diff --git a/tools/rl_training_tool.py b/tools/rl_training_tool.py index 566a2fb3..41559db4 100644 --- a/tools/rl_training_tool.py +++ b/tools/rl_training_tool.py @@ -37,12 +37,15 @@ import subprocess import sys import time import uuid +import logging from datetime import datetime import yaml from dataclasses import dataclass, field from pathlib import Path from typing import Any, Dict, List, Optional +logger = logging.getLogger(__name__) + # ============================================================================ # Path Configuration # ============================================================================ @@ -206,7 +209,7 @@ def _scan_environments() -> List[EnvironmentInfo]: )) break except Exception as e: - print(f"Warning: Could not parse {py_file}: {e}") + logger.warning("Could not parse %s: %s", py_file, e) return environments @@ -243,7 +246,7 @@ def _get_env_config_fields(env_file_path: str) -> Dict[str, Dict[str, Any]]: config_class = type(env_config) except Exception as config_error: # Fallback: try to import BaseEnvConfig directly from atroposlib - print(f"Note: config_init failed ({config_error}), using BaseEnvConfig defaults") + logger.info("config_init failed (%s), using BaseEnvConfig defaults", config_error) try: from atroposlib.envs.base import BaseEnvConfig config_class = BaseEnvConfig @@ -291,7 +294,7 @@ def _get_env_config_fields(env_file_path: str) -> Dict[str, Dict[str, Any]]: return fields except Exception as e: - print(f"Warning: Could not introspect environment config: {e}") + logger.warning("Could not introspect environment config: %s", e) return {} @@ -324,7 +327,7 @@ async def _spawn_training_run(run_state: RunState, config_path: Path): try: # Step 1: Start the Atropos API server (run-api) - print(f"[{run_id}] Starting Atropos API server (run-api)...") + logger.info("[%s] Starting Atropos API server (run-api)...", run_id) # File must stay open while the subprocess runs; we store the handle # on run_state so _stop_training_run() can close it when done. @@ -346,10 +349,10 @@ async def _spawn_training_run(run_state: RunState, config_path: Path): _stop_training_run(run_state) return - print(f"[{run_id}] Atropos API server started") + logger.info("[%s] Atropos API server started", run_id) # Step 2: Start the Tinker trainer - print(f"[{run_id}] Starting Tinker trainer: launch_training.py --config {config_path}") + logger.info("[%s] Starting Tinker trainer: launch_training.py --config %s", run_id, config_path) trainer_log_file = open(trainer_log, "w") # closed by _stop_training_run run_state.trainer_log_file = trainer_log_file @@ -362,7 +365,7 @@ async def _spawn_training_run(run_state: RunState, config_path: Path): ) # Wait for trainer to initialize (it starts FastAPI inference server on 8001) - print(f"[{run_id}] Waiting 30 seconds for trainer to initialize...") + logger.info("[%s] Waiting 30 seconds for trainer to initialize...", run_id) await asyncio.sleep(30) if run_state.trainer_process.poll() is not None: @@ -371,10 +374,10 @@ async def _spawn_training_run(run_state: RunState, config_path: Path): _stop_training_run(run_state) return - print(f"[{run_id}] Trainer started, inference server on port 8001") + logger.info("[%s] Trainer started, inference server on port 8001", run_id) # Step 3: Start the environment - print(f"[{run_id}] Waiting 90 more seconds before starting environment...") + logger.info("[%s] Waiting 90 more seconds before starting environment...", run_id) await asyncio.sleep(90) # Find the environment file @@ -390,7 +393,7 @@ async def _spawn_training_run(run_state: RunState, config_path: Path): _stop_training_run(run_state) return - print(f"[{run_id}] Starting environment: {env_info.file_path} serve") + logger.info("[%s] Starting environment: %s serve", run_id, env_info.file_path) env_log_file = open(env_log, "w") # closed by _stop_training_run run_state.env_log_file = env_log_file @@ -412,7 +415,7 @@ async def _spawn_training_run(run_state: RunState, config_path: Path): run_state.status = "running" run_state.start_time = time.time() - print(f"[{run_id}] Training run started successfully!") + logger.info("[%s] Training run started successfully!", run_id) # Start background monitoring asyncio.create_task(_monitor_training_run(run_state)) @@ -460,7 +463,7 @@ def _stop_training_run(run_state: RunState): """Stop all processes for a training run.""" # Stop in reverse order: env -> trainer -> api if run_state.env_process and run_state.env_process.poll() is None: - print(f"[{run_state.run_id}] Stopping environment process...") + logger.info("[%s] Stopping environment process...", run_state.run_id) run_state.env_process.terminate() try: run_state.env_process.wait(timeout=10) @@ -468,7 +471,7 @@ def _stop_training_run(run_state: RunState): run_state.env_process.kill() if run_state.trainer_process and run_state.trainer_process.poll() is None: - print(f"[{run_state.run_id}] Stopping trainer process...") + logger.info("[%s] Stopping trainer process...", run_state.run_id) run_state.trainer_process.terminate() try: run_state.trainer_process.wait(timeout=10) @@ -476,7 +479,7 @@ def _stop_training_run(run_state: RunState): run_state.trainer_process.kill() if run_state.api_process and run_state.api_process.poll() is None: - print(f"[{run_state.run_id}] Stopping API server...") + logger.info("[%s] Stopping API server...", run_state.run_id) run_state.api_process.terminate() try: run_state.api_process.wait(timeout=10)