Algorithm Registry
Algorithm Registry API — Advantage estimators, policy loss registries.
Base Registry Classes
The registry system provides a way to register and manage custom algorithm functions across distributed Ray environments.
class BaseFunctionRegistry
Base class for function registries with Ray actor synchronization.
Functions:
| Name | Description |
|---|---|
sync_with_actor | Sync local registry with Ray actor if Ray is available. |
register | Register a function. |
get | Get a function by name. |
list_available | List all registered functions. |
unregister | Unregister a function. Useful for testing. |
reset | Resets the registry (useful for testing purposes). |
repopulate | Repopulate the registry with the default functions. |
Source code in skyrl/backends/skyrl_train/utils/ppo_utils.py:222-423
class BaseFunctionRegistry:
    """Base class for function registries with Ray actor synchronization.

    Every subclass gets its own local ``_functions`` dict (created in
    ``__init_subclass__``). When Ray is initialized, registrations are mirrored
    into a named ``RegistryActor`` (named by ``_actor_name``) as cloudpickle
    bytes. Functions registered in one process can then be fetched by other
    processes that use the same named actor.

    Subclasses should override ``_actor_name`` and ``_function_type``, a
    human-readable label used in log and error messages. Subclasses may also
    define a ``repopulate_registry()`` classmethod that registers their default
    functions; ``repopulate()`` calls it after a reset.
    """

    # Subclasses should override these class attributes
    _actor_name = None
    _function_type = "Function"

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Give each subclass its own registry state instead of sharing the base class's.
        cls._functions = {}
        cls._ray_actor = None
        cls._synced_to_actor = False

    @classmethod
    def _get_or_create_actor(cls):
        """Return this registry's named RegistryActor handle, creating the actor if needed.

        Raises:
            Exception: If Ray is not initialized.
        """
        if not ray.is_initialized():
            raise Exception("Ray is not initialized, cannot create registry actor")
        if cls._ray_actor is None:
            # Use get_if_exists to create actor only if it doesn't exist
            cls._ray_actor = RegistryActor.options(name=cls._actor_name, get_if_exists=True).remote()
        return cls._ray_actor

    @classmethod
    def _sync_local_to_actor(cls):
        """Push all locally registered functions to the Ray actor.

        After a successful push this is a no-op. It runs again only once
        ``_synced_to_actor`` is cleared, either by ``sync_with_actor`` detecting
        a missing actor or by ``reset``.

        Raises:
            Exception: If Ray is not initialized, or if serialization or the actor call fails.
        """
        if cls._synced_to_actor:
            return
        if not ray.is_initialized():
            raise Exception("Ray is not initialized, cannot sync with actor")
        try:
            actor = cls._get_or_create_actor()
            if actor is not None:
                # Submit every registration first, then wait for all of them at once
                # instead of blocking on one round-trip per function.
                pending = [
                    actor.register.remote(name, cloudpickle.dumps(func))
                    for name, func in cls._functions.items()
                ]
                ray.get(pending)
                cls._synced_to_actor = True
        except Exception as e:
            logger.error(f"Error syncing {cls._function_type} to actor: {e}")
            raise

    @classmethod
    def sync_with_actor(cls):
        """Two-way sync between the local registry and the registry's Ray actor.

        Local functions are first pushed to the actor. Then every function the
        actor holds that this process does not is fetched and deserialized into
        the local registry.

        Raises:
            Exception: If Ray is not initialized, or if pushing to or
                deserializing from the actor fails.
        """
        # Only try if Ray is initialized
        if not ray.is_initialized():
            raise Exception("Ray is not initialized, cannot sync with actor")
        # First check if the actor is still alive
        # NOTE(Charlie): This is mainly for unit tests, where we run multiple unit tests in the
        # same Python process, and each unit test has ray init/shutdown. This makes cls's attributes
        # outdated (e.g. the _ray_actor points to a stale actor in the previous ray session).
        try:
            ray.get_actor(cls._actor_name)  # raises ValueError if the named actor does not exist
        except ValueError:
            cls._ray_actor = None
            cls._synced_to_actor = False
        # First, sync our local functions to the actor
        cls._sync_local_to_actor()
        actor = cls._get_or_create_actor()
        if actor is None:
            return
        available = ray.get(actor.list_available.remote())
        # Pull any function the actor has that we do not have locally
        for name in available:
            if name in cls._functions:
                continue
            func_serialized = ray.get(actor.get.remote(name))
            if func_serialized is None:
                continue
            try:
                cls._functions[name] = cloudpickle.loads(func_serialized)
            except Exception as e:
                # Deserialization failures are fatal: log for context, then re-raise.
                logger.error(f"Error deserializing {name} from actor: {e}")
                raise

    @classmethod
    def register(cls, name: Union[str, StrEnum], func: Callable):
        """Register a function.

        If Ray is initialized, this gets or creates a named Ray actor
        (RegistryActor) for the registry and pushes the function to it.
        If Ray is not initialized, the function is stored in the local
        registry only. To make locally registered functions available to all
        Ray processes, call sync_with_actor() after ray.init().

        Args:
            name: Name of the function to register. Can be a string or a StrEnum.
            func: Function to register.

        Raises:
            ValueError: If the function is already registered locally.
        """
        # Convert enum to string if needed
        # note: StrEnum is not cloudpickleable: https://github.com/cloudpipe/cloudpickle/issues/558
        if isinstance(name, StrEnum):
            name = name.value
        if name in cls._functions:
            raise ValueError(f"{cls._function_type} '{name}' already registered")
        # Always store in local registry first
        cls._functions[name] = func
        # Try to sync with Ray actor if Ray is initialized
        if ray.is_initialized():
            actor = cls._get_or_create_actor()
            if actor is not None:
                # Serialize the function using cloudpickle
                func_serialized = cloudpickle.dumps(func)
                ray.get(actor.register.remote(name, func_serialized))

    @classmethod
    def get(cls, name: str) -> Callable:
        """Get a function by name.

        If Ray is initialized, the local registry is first synced with the
        RegistryActor.

        Args:
            name: Name of the function to get. Can be a string or a StrEnum.

        Returns:
            The registered function.

        Raises:
            ValueError: If no function is registered under ``name``.
        """
        # Try to sync with actor first if Ray is available
        if ray.is_initialized():
            cls.sync_with_actor()
        if name not in cls._functions:
            available = list(cls._functions.keys())
            raise ValueError(f"Unknown {cls._function_type.lower()} '{name}'. Available: {available}")
        return cls._functions[name]

    @classmethod
    def list_available(cls) -> List[str]:
        """List the names of all registered functions (syncing with the actor first if Ray is initialized)."""
        # Try to sync with actor first if Ray is available
        if ray.is_initialized():
            cls.sync_with_actor()
        return list(cls._functions.keys())

    @classmethod
    def unregister(cls, name: Union[str, StrEnum]):
        """Unregister a function locally and, if Ray is initialized, from the actor. Useful for testing.

        Raises:
            ValueError: If ``name`` was found neither locally nor in the actor.
        """
        # Convert enum to string if needed
        if isinstance(name, StrEnum):
            name = name.value
        # Try to sync with actor first to get any functions that might be in the actor but not local
        if ray.is_initialized():
            cls.sync_with_actor()
        # Track if we found the function anywhere
        found_locally = name in cls._functions
        found_in_actor = False
        # Remove from local registry if it exists
        if found_locally:
            del cls._functions[name]
        # Try to remove from Ray actor if Ray is available
        if ray.is_initialized():
            actor = cls._get_or_create_actor()
            if actor is not None:
                # Check if it exists in actor first
                available_in_actor = ray.get(actor.list_available.remote())
                if name in available_in_actor:
                    found_in_actor = True
                    ray.get(actor.unregister.remote(name))
        # Only raise error if the function wasn't found anywhere
        if not found_locally and not found_in_actor:
            raise ValueError(f"{cls._function_type} '{name}' not registered")

    @classmethod
    def reset(cls):
        """Reset the registry (useful for testing purposes).

        If Ray is initialized and this class holds an actor handle, the named
        actor is killed. Local state is then cleared.
        """
        if ray.is_initialized() and cls._ray_actor is not None:
            try:
                actor = ray.get_actor(cls._actor_name)  # this raises exception if the actor is stale
                ray.kill(actor)
            except ValueError:
                pass  # Actor may already be gone
        cls._functions.clear()
        cls._ray_actor = None
        cls._synced_to_actor = False

    @classmethod
    def repopulate(cls):
        """Reset the registry, then re-register its default functions.

        The defaults come from the subclass's ``repopulate_registry()``
        classmethod, if it defines one. Otherwise the registry is left empty.
        """
        cls.reset()
        # Previously this registered the label string itself as a "function";
        # delegate to the subclass hook that actually knows the defaults.
        repopulate_defaults = getattr(cls, "repopulate_registry", None)
        if repopulate_defaults is not None:
            repopulate_defaults()
sync_with_actor() — Sync local registry with Ray actor if Ray is available.
Source code in skyrl/backends/skyrl_train/utils/ppo_utils.py:265-303
@classmethod
def sync_with_actor(cls):
    """Two-way sync between the local registry and the registry's Ray actor.

    Local functions are first pushed to the actor (``_sync_local_to_actor``).
    Then every function the actor holds that this process does not is fetched
    and deserialized into the local registry.

    Raises:
        Exception: If Ray is not initialized, or if pushing to or deserializing
            from the actor fails.
    """
    # Only try if Ray is initialized
    if not ray.is_initialized():
        raise Exception("Ray is not initialized, cannot sync with actor")
    # First check if the actor is still alive
    # NOTE(Charlie): This is mainly for unit tests, where we run multiple unit tests in the
    # same Python process, and each unit test has ray init/shutdown. This makes cls's attributes
    # outdated (e.g. the _ray_actor points to a stale actor in the previous ray session).
    try:
        ray.get_actor(cls._actor_name)  # raises ValueError if the named actor does not exist
    except ValueError:
        cls._ray_actor = None
        cls._synced_to_actor = False
    # First, sync our local functions to the actor
    cls._sync_local_to_actor()
    actor = cls._get_or_create_actor()
    if actor is None:
        return
    available = ray.get(actor.list_available.remote())
    # Pull any function the actor has that we do not have locally
    for name in available:
        if name in cls._functions:
            continue
        func_serialized = ray.get(actor.get.remote(name))
        if func_serialized is None:
            continue
        try:
            cls._functions[name] = cloudpickle.loads(func_serialized)
        except Exception as e:
            # Deserialization failures are fatal: log for context, then re-raise.
            logger.error(f"Error deserializing {name} from actor: {e}")
            raise
register(name: Union[str, StrEnum], func: Callable) — Register a function.
If ray is initialized, this function will get or create a named ray actor (RegistryActor) for the registry, and sync the registry to the actor.
If Ray is not initialized, the function will be stored in the local registry only.
To make sure all locally registered functions are available to all ray processes, call sync_with_actor() after ray.init().
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
name | Union[str, StrEnum] | Name of the function to register. Can be a string or a StrEnum. | required |
func | Callable | Function to register. | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If the function is already registered. |
Source code in skyrl/backends/skyrl_train/utils/ppo_utils.py:305-341
@classmethod
def register(cls, name: Union[str, StrEnum], func: Callable):
    """Add ``func`` to the registry under ``name``.

    The function always goes into this process's local registry. When Ray is
    initialized, it is also cloudpickled and pushed to the registry's named
    RegistryActor, which is created on first use. When Ray is not initialized,
    only the local registry changes. Call sync_with_actor() after ray.init()
    to make such functions visible to other Ray processes.

    Args:
        name: Registry key, either a plain string or a StrEnum member.
        func: The callable to store.

    Raises:
        ValueError: If ``name`` is already present in the local registry.
    """
    # StrEnum members are not cloudpickleable
    # (https://github.com/cloudpipe/cloudpickle/issues/558), so keep the raw value.
    name = name.value if isinstance(name, StrEnum) else name
    if name in cls._functions:
        raise ValueError(f"{cls._function_type} '{name}' already registered")
    cls._functions[name] = func

    if not ray.is_initialized():
        return
    registry_actor = cls._get_or_create_actor()
    if registry_actor is None:
        return
    payload = cloudpickle.dumps(func)
    ray.get(registry_actor.register.remote(name, payload))
get(name: str) -> Callable — Get a function by name.
If ray is initialized, this function will first sync the local registry with the RegistryActor. Then it will return the function if it is found in the registry.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
name | str | Name of the function to get. Can be a string or a StrEnum. | required |
Returns:
| Type | Description |
|---|---|
| Callable | The function if it is found in the registry. |
Source code in skyrl/backends/skyrl_train/utils/ppo_utils.py:343-363
@classmethod
def get(cls, name: str) -> Callable:
    """Look up a registered function by name.

    When Ray is initialized, the local registry is first synced with the
    RegistryActor, so functions registered from other processes are found too.

    Args:
        name: Registry key, either a plain string or a StrEnum member.

    Returns:
        The function stored under ``name``.

    Raises:
        ValueError: If nothing is registered under ``name``.
    """
    if ray.is_initialized():
        cls.sync_with_actor()
    registered = cls._functions
    if name in registered:
        return registered[name]
    available = list(registered)
    raise ValueError(f"Unknown {cls._function_type.lower()} '{name}'. Available: {available}")
list_available() -> List[str] — List all registered functions.
Source code in skyrl/backends/skyrl_train/utils/ppo_utils.py:365-371
@classmethod
def list_available(cls) -> List[str]:
    """Return the names of every registered function.

    When Ray is initialized, the local registry is synced with the RegistryActor
    before the names are collected.
    """
    if ray.is_initialized():
        cls.sync_with_actor()
    return [registered_name for registered_name in cls._functions]
unregister(name: Union[str, StrEnum]) — Unregister a function. Useful for testing.
Source code in skyrl/backends/skyrl_train/utils/ppo_utils.py:373-404
@classmethod
def unregister(cls, name: Union[str, StrEnum]):
    """Remove a function from the registry (mainly useful in tests).

    The entry is dropped from the local registry and, when Ray is initialized,
    from the RegistryActor as well.

    Raises:
        ValueError: If ``name`` is found neither locally nor in the actor.
    """
    if isinstance(name, StrEnum):
        name = name.value

    # Pull in actor-only entries first so the local membership check sees them.
    if ray.is_initialized():
        cls.sync_with_actor()

    found_locally = name in cls._functions
    if found_locally:
        cls._functions.pop(name)

    found_in_actor = False
    if ray.is_initialized():
        registry_actor = cls._get_or_create_actor()
        if registry_actor is not None and name in ray.get(registry_actor.list_available.remote()):
            found_in_actor = True
            ray.get(registry_actor.unregister.remote(name))

    if not (found_locally or found_in_actor):
        raise ValueError(f"{cls._function_type} '{name}' not registered")
reset() — Reset the registry (useful for testing purposes).
Source code in skyrl/backends/skyrl_train/utils/ppo_utils.py:406-417
@classmethod
def reset(cls):
    """Wipe the registry, mainly for tests.

    If Ray is initialized and this class holds an actor handle, the named actor
    is looked up and killed; a missing actor is ignored. Local state is then
    cleared.
    """
    holds_live_actor = ray.is_initialized() and cls._ray_actor is not None
    if holds_live_actor:
        try:
            ray.kill(ray.get_actor(cls._actor_name))
        except ValueError:
            pass  # the actor is already gone
    cls._functions.clear()
    cls._synced_to_actor = False
    cls._ray_actor = None
repopulate() — Repopulate the registry with the default functions.
Source code in skyrl/backends/skyrl_train/utils/ppo_utils.py:419-423
@classmethod
def repopulate(cls):
    """Reset the registry, then re-register its default functions.

    The defaults come from the subclass's ``repopulate_registry()`` classmethod,
    if it defines one. Otherwise the registry is left empty after the reset.
    """
    cls.reset()
    # Previously this registered the label string ``cls._function_type`` as if it
    # were a function; delegate to the subclass hook that knows the real defaults.
    repopulate_defaults = getattr(cls, "repopulate_registry", None)
    if repopulate_defaults is not None:
        repopulate_defaults()
RegistryActor() — Shared Ray actor for managing function registries across processes.
Functions:
| Name | Description |
|---|---|
register | Register a serialized function. |
get | Get a serialized function by name. |
list_available | List all available function names. |
unregister | Unregister a function by name. |
Attributes:
| Name | Type | Description |
|---|---|---|
registry | dict | Maps each function name to its serialized function bytes. |
Source code in skyrl/backends/skyrl_train/utils/ppo_utils.py:198-219
@ray.remote
class RegistryActor:
    """Ray actor holding one registry's functions as serialized bytes.

    The actor stores whatever bytes it receives; serialization and
    deserialization happen in the callers.
    """

    def __init__(self):
        self.registry = {}  # name -> serialized function bytes

    def register(self, name: str, func_serialized: bytes):
        """Store ``func_serialized`` under ``name``, replacing any earlier entry."""
        store = self.registry
        store[name] = func_serialized

    def get(self, name: str):
        """Return the bytes stored under ``name``, or ``None`` when absent."""
        store = self.registry
        return store[name] if name in store else None

    def list_available(self):
        """Return every stored name, in insertion order."""
        return [*self.registry]

    def unregister(self, name: str):
        """Drop ``name`` and return its bytes, or ``None`` if it was not stored."""
        store = self.registry
        if name not in store:
            return None
        return store.pop(name)
registry = {}method register
register(name: str, func_serialized: bytes)Register a serialized function.
Source code in skyrl/backends/skyrl_train/utils/ppo_utils.py:205-207
def register(self, name: str, func_serialized: bytes):
    """Store ``func_serialized`` under ``name``, replacing any earlier entry."""
    store = self.registry
    store[name] = func_serialized
get(name: str)Get a serialized function by name.
Source code in skyrl/backends/skyrl_train/utils/ppo_utils.py:209-211
def get(self, name: str):
    """Return the bytes stored under ``name``, or ``None`` when absent."""
    store = self.registry
    return store[name] if name in store else None
list_available()List all available function names.
Source code in skyrl/backends/skyrl_train/utils/ppo_utils.py:213-215
def list_available(self):
    """Return every stored name, in insertion order."""
    return [*self.registry]
unregister(name: str)Unregister a function by name.
Source code in skyrl/backends/skyrl_train/utils/ppo_utils.py:217-219
def unregister(self, name: str):
    """Drop ``name`` and return its bytes, or ``None`` if it was not stored."""
    store = self.registry
    if name not in store:
        return None
    return store.pop(name)
sync_registries() — Sync the registries with the Ray actor once Ray is initialized.
Advantage Estimator Registry
The advantage estimator registry manages functions that compute advantages and returns.
class AdvantageEstimatorRegistry
Bases: BaseFunctionRegistry
Registry for advantage estimator functions.
This registry allows users to register custom advantage estimators without modifying the skyrl_train package. Custom estimators can be registered by calling AdvantageEstimatorRegistry.register() directly or by using the @register_advantage_estimator decorator.
See examples/algorithms/custom_advantage_estimator for a simple example of how to register and use custom advantage estimators.
Functions:
| Name | Description |
|---|---|
sync_with_actor | Sync local registry with Ray actor if Ray is available. |
register | Register a function. |
get | Get a function by name. |
list_available | List all registered functions. |
unregister | Unregister a function. Useful for testing. |
reset | Resets the registry (useful for testing purposes). |
repopulate | Repopulate the registry with the default functions. |
repopulate_registry | Register the default advantage estimators that are not already registered. |
Source code in skyrl/backends/skyrl_train/utils/ppo_utils.py:433-461
class AdvantageEstimatorRegistry(BaseFunctionRegistry):
    """
    Registry of advantage estimator functions.

    Custom estimators can be added without modifying the skyrl_train package,
    either by calling AdvantageEstimatorRegistry.register() directly or by using
    the @register_advantage_estimator decorator.

    A minimal end-to-end example lives in examples/algorithms/custom_advantage_estimator.
    """

    _actor_name = "advantage_estimator_registry"
    _function_type = "advantage estimator"

    @classmethod
    def repopulate_registry(cls):
        """Register each built-in advantage estimator that is not already registered."""
        already_registered = set(cls.list_available())
        defaults = (
            ("grpo", AdvantageEstimator.GRPO, compute_grpo_outcome_advantage),
            ("gae", AdvantageEstimator.GAE, compute_gae_advantage_return),
            ("rloo", AdvantageEstimator.RLOO, compute_rloo_outcome_advantage),
            ("reinforce++", AdvantageEstimator.REINFORCE_PP, compute_reinforce_plus_plus_outcome_advantage),
        )
        for key, estimator_type, estimator_fn in defaults:
            if key in already_registered:
                continue
            cls.register(estimator_type, estimator_fn)
sync_with_actor()Sync local registry with Ray actor if Ray is available.
method register
register(name: Union[str, StrEnum], func: Callable)Register a function.
If ray is initialized, this function will get or create a named ray actor (RegistryActor) for the registry, and sync the registry to the actor.
If Ray is not initialized, the function will be stored in the local registry only.
To make sure all locally registered functions are available to all ray processes, call sync_with_actor() after ray.init().
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
name | Union[str, StrEnum] | Name of the function to register. Can be a string or a StrEnum. | required |
func | Callable | Function to register. | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If the function is already registered. |
method classmethod get
get(name: str) -> CallableGet a function by name.
If ray is initialized, this function will first sync the local registry with the RegistryActor. Then it will return the function if it is found in the registry.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
name | str | Name of the function to get. Can be a string or a StrEnum. | required |
Returns:
| Type | Description |
|---|---|
| Callable | The function if it is found in the registry. |
method classmethod list_available
list_available() -> List[str]List all registered functions.
method classmethod unregister
unregister(name: Union[str, StrEnum])Unregister a function. Useful for testing.
method classmethod reset
reset()Resets the registry (useful for testing purposes).
method classmethod repopulate
repopulate()Repopulate the registry with the default functions.
method classmethod repopulate_registry
repopulate_registry() — Register the default advantage estimators that are not already registered.
Source code in skyrl/backends/skyrl_train/utils/ppo_utils.py:449-461
@classmethod
def repopulate_registry(cls):
    """Register each built-in advantage estimator that is not already registered."""
    already_registered = set(cls.list_available())
    defaults = (
        ("grpo", AdvantageEstimator.GRPO, compute_grpo_outcome_advantage),
        ("gae", AdvantageEstimator.GAE, compute_gae_advantage_return),
        ("rloo", AdvantageEstimator.RLOO, compute_rloo_outcome_advantage),
        ("reinforce++", AdvantageEstimator.REINFORCE_PP, compute_reinforce_plus_plus_outcome_advantage),
    )
    for key, estimator_type, estimator_fn in defaults:
        if key in already_registered:
            continue
        cls.register(estimator_type, estimator_fn)
Bases: StrEnum
Attributes:
| Name | Type | Description |
|---|---|---|
GAE | ||
GRPO | ||
RLOO | ||
REINFORCE_PP |
Source code in skyrl/backends/skyrl_train/utils/ppo_utils.py:426-430
class AdvantageEstimator(StrEnum):
GAE = "gae"
GRPO = "grpo"
RLOO = "rloo"
REINFORCE_PP = "reinforce++"attr GAE
GAE = 'gae'attr GRPO
GRPO = 'grpo'attr RLOO
RLOO = 'rloo'attr REINFORCE_PP
REINFORCE_PP = 'reinforce++'method register_advantage_estimator
register_advantage_estimator(name: Union[str, AdvantageEstimator]) — Decorator to register an advantage estimator function.
Policy Loss Registry
The policy loss registry manages functions that compute policy losses for PPO.
class PolicyLossRegistry
Bases: BaseFunctionRegistry
Registry for policy loss functions.
This registry allows users to register custom policy loss functions without modifying the skyrl_train package. Custom functions can be registered by calling PolicyLossRegistry.register() directly or by using the @register_policy_loss decorator.
See examples/algorithms/custom_policy_loss for a simple example of how to register and use custom policy loss functions.
Functions:
| Name | Description |
|---|---|
sync_with_actor | Sync local registry with Ray actor if Ray is available. |
register | Register a function. |
get | Get a function by name. |
list_available | List all registered functions. |
unregister | Unregister a function. Useful for testing. |
reset | Resets the registry (useful for testing purposes). |
repopulate | Repopulate the registry with the default functions. |
repopulate_registry | Repopulate the registry with default policy loss functions. |
Source code in skyrl/backends/skyrl_train/utils/ppo_utils.py:476-509
class PolicyLossRegistry(BaseFunctionRegistry):
    """
    Registry of policy loss functions.

    Custom losses can be added without modifying the skyrl_train package,
    either by calling PolicyLossRegistry.register() directly or by using the
    @register_policy_loss decorator.

    A minimal end-to-end example lives in examples/algorithms/custom_policy_loss.
    """

    _actor_name = "policy_loss_registry"
    _function_type = "policy loss"

    @classmethod
    def repopulate_registry(cls):
        """Repopulate the registry with default policy loss functions."""
        already_registered = set(cls.list_available())
        # "regular" and "dual_clip" intentionally share ppo_policy_loss.
        # NOTE(review): PolicyLossType.CISPO has no default entry here; confirm
        # whether that is intended.
        defaults = (
            ("regular", PolicyLossType.REGULAR, ppo_policy_loss),
            ("dual_clip", PolicyLossType.DUAL_CLIP, ppo_policy_loss),
            ("gspo", PolicyLossType.GSPO, gspo_policy_loss),
            ("clip_cov", PolicyLossType.CLIP_COV, compute_policy_loss_clip_cov),
            ("kl_cov", PolicyLossType.KL_COV, compute_policy_loss_kl_cov),
            ("sapo", PolicyLossType.SAPO, sapo_policy_loss),
            ("cross_entropy", PolicyLossType.CROSS_ENTROPY, cross_entropy_loss),
            ("importance_sampling", PolicyLossType.IMPORTANCE_SAMPLING, importance_sampling_loss),
        )
        for key, loss_type, loss_fn in defaults:
            if key in already_registered:
                continue
            cls.register(loss_type, loss_fn)
sync_with_actor()Sync local registry with Ray actor if Ray is available.
method register
register(name: Union[str, StrEnum], func: Callable)Register a function.
If ray is initialized, this function will get or create a named ray actor (RegistryActor) for the registry, and sync the registry to the actor.
If Ray is not initialized, the function will be stored in the local registry only.
To make sure all locally registered functions are available to all ray processes, call sync_with_actor() after ray.init().
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
name | Union[str, StrEnum] | Name of the function to register. Can be a string or a StrEnum. | required |
func | Callable | Function to register. | required |
Raises:
| Type | Description |
|---|---|
| ValueError | If the function is already registered. |
method classmethod get
get(name: str) -> CallableGet a function by name.
If ray is initialized, this function will first sync the local registry with the RegistryActor. Then it will return the function if it is found in the registry.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
name | str | Name of the function to get. Can be a string or a StrEnum. | required |
Returns:
| Type | Description |
|---|---|
| Callable | The function if it is found in the registry. |
method classmethod list_available
list_available() -> List[str]List all registered functions.
method classmethod unregister
unregister(name: Union[str, StrEnum])Unregister a function. Useful for testing.
method classmethod reset
reset()Resets the registry (useful for testing purposes).
method classmethod repopulate
repopulate()Repopulate the registry with the default functions.
method classmethod repopulate_registry
repopulate_registry() — Repopulate the registry with default policy loss functions.
Source code in skyrl/backends/skyrl_train/utils/ppo_utils.py:492-509
@classmethod
def repopulate_registry(cls):
    """Repopulate the registry with default policy loss functions."""
    already_registered = set(cls.list_available())
    # "regular" and "dual_clip" intentionally share ppo_policy_loss.
    defaults = (
        ("regular", PolicyLossType.REGULAR, ppo_policy_loss),
        ("dual_clip", PolicyLossType.DUAL_CLIP, ppo_policy_loss),
        ("gspo", PolicyLossType.GSPO, gspo_policy_loss),
        ("clip_cov", PolicyLossType.CLIP_COV, compute_policy_loss_clip_cov),
        ("kl_cov", PolicyLossType.KL_COV, compute_policy_loss_kl_cov),
        ("sapo", PolicyLossType.SAPO, sapo_policy_loss),
        ("cross_entropy", PolicyLossType.CROSS_ENTROPY, cross_entropy_loss),
        ("importance_sampling", PolicyLossType.IMPORTANCE_SAMPLING, importance_sampling_loss),
    )
    for key, loss_type, loss_fn in defaults:
        if key in already_registered:
            continue
        cls.register(loss_type, loss_fn)
Bases: StrEnum
Attributes:
| Name | Type | Description |
|---|---|---|
REGULAR | ||
DUAL_CLIP | ||
GSPO | ||
CISPO | ||
CLIP_COV | ||
KL_COV | ||
SAPO | ||
CROSS_ENTROPY | ||
IMPORTANCE_SAMPLING |
Source code in skyrl/backends/skyrl_train/utils/ppo_utils.py:464-473
class PolicyLossType(StrEnum):
REGULAR = "regular"
DUAL_CLIP = "dual_clip"
GSPO = "gspo"
CISPO = "cispo"
CLIP_COV = "clip_cov"
KL_COV = "kl_cov"
SAPO = "sapo"
CROSS_ENTROPY = "cross_entropy"
IMPORTANCE_SAMPLING = "importance_sampling"attr REGULAR
REGULAR = 'regular'attr DUAL_CLIP
DUAL_CLIP = 'dual_clip'attr GSPO
GSPO = 'gspo'attr CISPO
CISPO = 'cispo'attr CLIP_COV
CLIP_COV = 'clip_cov'attr KL_COV
KL_COV = 'kl_cov'attr SAPO
SAPO = 'sapo'attr CROSS_ENTROPY
CROSS_ENTROPY = 'cross_entropy'attr IMPORTANCE_SAMPLING
IMPORTANCE_SAMPLING = 'importance_sampling'method register_policy_loss
register_policy_loss(name: Union[str, PolicyLossType]) — Decorator to register a policy loss function.