-
Notifications
You must be signed in to change notification settings - Fork 7.2k
[RLlib] Clean up offline prelearner and its unit testing #60632
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
[RLlib] Clean up offline prelearner and its unit testing #60632
Conversation
|
|
||
| # If multi-agent we need to extract the agent ID. | ||
| # TODO (simon): Check, what happens with the module ID. | ||
| if is_multi_agent: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
multi agent case is not implemented
| to_numpy: bool = False, | ||
| input_compress_columns: Optional[List[str]] = None, | ||
| **kwargs: Dict[str, Any], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We don't need these parameters as arguments of this method because they can be expected to be constant over the lifetime of an OfflinePreLearner.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request effectively cleans up the OfflinePreLearner class and its associated unit tests, leading to improved maintainability and faster test execution. The deprecation of ignore_final_observation and the refactoring of method signatures are well-executed. However, the refactoring has introduced breaking changes in some unit tests by converting static methods to instance methods without updating their call sites in the tests. These issues are critical and need to be addressed. I've also identified a minor type hint inaccuracy that should be corrected.
| def _map_to_episodes( | ||
| is_multi_agent: bool, | ||
| self, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changing _map_to_episodes from a static method to an instance method breaks the test_offline_prelearner_convert_to_episodes unit test in rllib/offline/tests/test_offline_prelearner.py. The test still calls this method statically (OfflinePreLearner._map_to_episodes(...)), which will now fail. The test needs to be updated to instantiate OfflinePreLearner and then call this method on the instance.
| def _map_sample_batch_to_episode( | ||
| is_multi_agent: bool, | ||
| self, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changing _map_sample_batch_to_episode from a static method to an instance method breaks the test_offline_prelearner_convert_from_old_sample_batch_to_episodes unit test in rllib/offline/tests/test_offline_prelearner.py. The test still calls this method statically (OfflinePreLearner._map_sample_batch_to_episode(...)), which will now fail. The test needs to be updated to instantiate OfflinePreLearner and then call this method on the instance.
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
| def _validate_deprecated_map_args(kwargs: dict, config: "AlgorithmConfig") -> Set: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The return type hint for this function is Set, but it returns a tuple of three elements (is_multi_agent, schema, input_compress_columns). This should be corrected to Tuple[bool, Dict, List] for better type safety and clarity.
| def _validate_deprecated_map_args(kwargs: dict, config: "AlgorithmConfig") -> Set: | |
| def _validate_deprecated_map_args(kwargs: dict, config: "AlgorithmConfig") -> Tuple[bool, Dict, List]: |
| unpacked_obs = ( | ||
| unpack_if_needed(obs) | ||
| if Columns.OBS in input_compress_columns | ||
| else obs | ||
| ) | ||
| # Set the next observation. | ||
| if ignore_final_observation: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ignore final observation is not tested and we don't use it ourselves anywhere in the codebase.
| # If multi-agent we need to extract the agent ID. | ||
| # TODO (simon): Check, what happens with the module ID. | ||
| if is_multi_agent: | ||
| agent_id = ( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
multi agent case is not implemented
| input_compress_columns: Optional[List[str]] = None, | ||
| ignore_final_observation: Optional[bool] = False, | ||
| observation_space: gym.Space = None, | ||
| action_space: gym.Space = None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We don't need these parameters as arguments of this method because they can be expected to be constant over the lifetime of an OfflinePreLearner.
|
|
||
| # Run the `Learner`'s connector pipeline. | ||
| batch = self._learner_connector( | ||
| rl_module=self._module, | ||
| batch={}, | ||
| episodes=episodes, | ||
| shared_data={}, | ||
| # TODO (sven): Add MetricsLogger to non-Learner components that have a | ||
| # LearnerConnector pipeline. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Comment moved to group with other TODOs
| to_numpy=False, | ||
| input_compress_columns=self.config.input_compress_columns, | ||
| observation_space=self.observation_space, | ||
| action_space=self.action_space, | ||
| )["episodes"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The effect of removing these args can be seen here. 5 of the 7 arguments are constant so no need to parameterize.
| self.config: AlgorithmConfig = config | ||
| self.input_read_episodes: bool = self.config.input_read_episodes | ||
| self.input_read_sample_batches: bool = self.config.input_read_sample_batches |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These parameters are only used in 1 or max 2 places, so we can just access config there and keep the constructor clean.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Cursor Bugbot has reviewed your changes and found 6 potential issues.
|
|
||
| # Set to empty list, if `None`. | ||
| input_compress_columns = input_compress_columns or [] | ||
| is_multi_agent, schema, input_compress_columns = _validate_deprecated_map_args(kwargs, self) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Function receives wrong object type, expects AlgorithmConfig
High Severity
The _validate_deprecated_map_args function is called with self (an OfflinePreLearner instance) but the function expects an AlgorithmConfig and tries to access config.is_multi_agent, config.input_read_schema, and config.input_compress_columns. These attributes don't exist on OfflinePreLearner—they exist on self.config. This will cause AttributeError whenever _map_to_episodes or _map_sample_batch_to_episode is called without deprecated args in kwargs.
Additional Locations (1)
| self._is_multi_agent, | ||
| batch, | ||
| to_numpy=True, | ||
| schema=SCHEMA | self.config.input_read_schema, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OfflinePolicyPreEvaluator references removed parent class attributes
High Severity
OfflinePolicyPreEvaluator extends OfflinePreLearner and its __call__ method references self.input_read_episodes, self.input_read_sample_batches, and self._is_multi_agent. These instance attributes were removed from OfflinePreLearner.__init__ in this PR (now accessed via self.config.input_read_episodes etc.). This will cause AttributeError when the evaluator runs.
Additional Locations (1)
| episodes = self._postprocess_and_sample(episodes) | ||
|
|
||
| # Else, if we have old stack `SampleBatch`es. | ||
| elif self.input_read_sample_batches: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Buffer kwargs merge order reverses user/default priority
Medium Severity
The buffer kwargs merge order is reversed. The old code used defaults | user_config (user overrides defaults), but the new code uses user_config | defaults (defaults override user). This means user-specified capacity or batch_size_B values in prelearner_buffer_kwargs will be silently ignored and overwritten by the hardcoded defaults.
| self.assertTrue(len(episodes) == 10) | ||
| self.assertTrue(isinstance(episodes[0], SingleAgentEpisode)) | ||
|
|
||
| def test_offline_prelearner_ignore_final_observation(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Test calls instance method as static method
Medium Severity
The test calls OfflinePreLearner._map_to_episodes(False, batch) as if it were a static method, but _map_to_episodes was changed from a @staticmethod to an instance method in this PR. The boolean False will be bound to self, causing the method to fail when accessing self.config since a boolean doesn't have that attribute.
| # if refs: | ||
| # module_state = ray.get(self._future) | ||
| # | ||
| # self._module.set_state(module_state) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OfflinePreLearner.call calls instance method as class method
High Severity
In OfflinePreLearner.__call__, the call OfflinePreLearner._map_sample_batch_to_episode(batch, to_numpy=True) invokes the method on the class rather than the instance. Since _map_sample_batch_to_episode is now an instance method, the batch dict gets bound to self, leaving the required batch parameter unsatisfied. This will cause TypeError: missing 1 required positional argument: 'batch'. The call should be self._map_sample_batch_to_episode(batch, to_numpy=True).


Description
offline pre learner unit tests are timing out often.
This is because
test_offline_prelearner_sample_from_episode_datatakes 2 minutes on my macbook pro because it collects many samples. I'm not sure how long it takes on CI, but appears to be long enough to time out often. It also uses two env runners by default, which results in two Ray Data datasets executed at the same time for writing, which spawns too many tasks on my dev machine for unittesting and it freezes while the test is running (same for @pseudo-rnd-thoughts ).Therefore, this PR reduces test runtime from >2 minutes to 8 seconds on my MBP and uses less resources with only one env runner. The PR also cleans up the OfflinePreLearnerClass to make it more maintainable for upcoming changes.
Removes >150loc. Added lines are mostly handling deprecation.