SkyRL
API Reference › SkyRL › SkyRL-Train Backend

Algorithm Registry

Algorithm Registry API — registries for advantage estimators and policy loss functions.

Base Registry Classes

The registry system provides a way to register and manage custom algorithm functions across distributed Ray environments.

class BaseFunctionRegistry

Base class for function registries with Ray actor synchronization.

Functions:

| Name | Description |
| --- | --- |
| sync_with_actor | Sync the local registry with the Ray actor (requires Ray to be initialized). |
| register | Register a function. |
| get | Get a function by name. |
| list_available | List all registered functions. |
| unregister | Unregister a function. Useful for testing. |
| reset | Reset the registry (useful for testing). |
| repopulate | Repopulate the registry with the default functions. |
Source code in skyrl/backends/skyrl_train/utils/ppo_utils.py:222-423
class BaseFunctionRegistry:
    """Base class for function registries with Ray actor synchronization.

    Every subclass gets its own process-local ``_functions`` dict (set up in
    ``__init_subclass__``). When Ray is initialized, entries are also stored as
    cloudpickled bytes in a named ``RegistryActor`` (``_actor_name``), so a
    function registered in one process can be looked up from another.

    Subclasses set ``_actor_name`` / ``_function_type`` and may override
    ``repopulate_registry`` to register their default functions.
    """

    # Subclasses should override these class attributes
    _actor_name = None  # name of the shared RegistryActor for this registry
    _function_type = "Function"  # human-readable label used in error/log messages

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Give each subclass independent state rather than sharing it through the base.
        cls._functions = {}
        cls._ray_actor = None
        cls._synced_to_actor = False

    @classmethod
    def _get_or_create_actor(cls):
        """Get or create the Ray actor for managing the registry using get_if_exists.

        The actor handle is cached in ``cls._ray_actor``.

        Raises:
            Exception: If Ray is not initialized.
        """
        if not ray.is_initialized():
            raise Exception("Ray is not initialized, cannot create registry actor")

        if cls._ray_actor is None:
            # Use get_if_exists to create actor only if it doesn't exist
            cls._ray_actor = RegistryActor.options(name=cls._actor_name, get_if_exists=True).remote()
        return cls._ray_actor

    @classmethod
    def _sync_local_to_actor(cls):
        """Push every local function to the Ray actor.

        This is a no-op once ``_synced_to_actor`` is set. The flag is cleared by
        ``reset`` and by ``sync_with_actor`` when the named actor no longer exists.

        Raises:
            Exception: If Ray is not initialized, or re-raised from any failure while
                serializing or sending a function.
        """
        if cls._synced_to_actor:
            return
        if not ray.is_initialized():
            raise Exception("Ray is not initialized, cannot sync with actor")

        try:
            actor = cls._get_or_create_actor()
            if actor is not None:
                for name, func in cls._functions.items():
                    func_serialized = cloudpickle.dumps(func)
                    ray.get(actor.register.remote(name, func_serialized))
                cls._synced_to_actor = True
        except Exception as e:
            logger.error(f"Error syncing {cls._function_type} to actor: {e}")
            raise

    @classmethod
    def sync_with_actor(cls):
        """Two-way sync between the local registry and the shared Ray actor.

        First pushes local entries to the actor (skipped if already done, see
        ``_synced_to_actor``), then pulls every name the actor holds that is missing
        locally.

        Raises:
            Exception: If Ray is not initialized.
            Exception: Re-raised if a function fetched from the actor cannot be
                deserialized.
        """
        if not ray.is_initialized():
            raise Exception("Ray is not initialized, cannot sync with actor")

        # Drop cached actor state if the named actor no longer exists.
        # NOTE(Charlie): This is mainly for unit tests, where we run multiple unit tests in the
        # same Python process, and each unit test has ray init/shutdown. This makes cls's attributes
        # outdated (e.g. the _ray_actor points to a stale actor in the previous ray session).
        try:
            _ = ray.get_actor(cls._actor_name)  # raises ValueError if no actor has this name
        except ValueError:
            cls._ray_actor = None
            cls._synced_to_actor = False

        # Push: local -> actor.
        cls._sync_local_to_actor()

        actor = cls._get_or_create_actor()
        if actor is None:
            return

        available = ray.get(actor.list_available.remote())

        # Pull: actor -> local, only for names this process does not already have.
        for name in available:
            if name not in cls._functions:
                func_serialized = ray.get(actor.get.remote(name))
                if func_serialized is not None:
                    try:
                        func = cloudpickle.loads(func_serialized)
                    except Exception as e:
                        # A function that cannot be deserialized here is a hard error, not skipped.
                        logger.error(f"Error deserializing {name} from actor: {e}")
                        raise
                    cls._functions[name] = func

    @classmethod
    def register(cls, name: Union[str, StrEnum], func: Callable):
        """Register a function.

        If Ray is initialized, this gets or creates the named ray actor (RegistryActor)
        for the registry and pushes the new function to it.

        If Ray is not initialized, the function is stored in the local registry only.

        To make sure all locally registered functions are available to all Ray processes,
        call sync_with_actor() after ray.init().

        Args:
            name: Name of the function to register. Can be a string or a StrEnum.
            func: Function to register.

        Raises:
            ValueError: If a function with this name is already registered locally.
        """
        # Convert enum to string if needed
        # note: StrEnum is not cloudpickleable: https://github.com/cloudpipe/cloudpickle/issues/558
        if isinstance(name, StrEnum):
            name = name.value

        if name in cls._functions:
            raise ValueError(f"{cls._function_type} '{name}' already registered")

        # Always store in local registry first
        cls._functions[name] = func

        # Try to sync with Ray actor if Ray is initialized
        if ray.is_initialized():
            actor = cls._get_or_create_actor()
            if actor is not None:
                # Serialize the function using cloudpickle
                func_serialized = cloudpickle.dumps(func)
                ray.get(actor.register.remote(name, func_serialized))

    @classmethod
    def get(cls, name: Union[str, StrEnum]) -> Callable:
        """Get a function by name.

        If Ray is initialized, this first syncs the local registry with the RegistryActor.
        It then returns the function if it is found in the registry.

        Args:
            name: Name of the function to get. Can be a string or a StrEnum.

        Returns:
            The function if it is found in the registry.

        Raises:
            ValueError: If no function is registered under ``name``.
        """
        # Normalize enums the same way register/unregister do.
        if isinstance(name, StrEnum):
            name = name.value

        # Try to sync with actor first if Ray is available
        if ray.is_initialized():
            cls.sync_with_actor()

        if name not in cls._functions:
            available = list(cls._functions.keys())
            raise ValueError(f"Unknown {cls._function_type.lower()} '{name}'. Available: {available}")
        return cls._functions[name]

    @classmethod
    def list_available(cls) -> List[str]:
        """List all registered function names (after syncing with the actor if Ray is initialized)."""
        # Try to sync with actor first if Ray is available
        if ray.is_initialized():
            cls.sync_with_actor()
        return list(cls._functions.keys())

    @classmethod
    def unregister(cls, name: Union[str, StrEnum]):
        """Unregister a function locally and from the Ray actor. Useful for testing.

        Raises:
            ValueError: If ``name`` is registered neither locally nor in the actor.
        """
        # Convert enum to string if needed
        if isinstance(name, StrEnum):
            name = name.value

        # Try to sync with actor first to get any functions that might be in the actor but not local
        if ray.is_initialized():
            cls.sync_with_actor()

        # Track if we found the function anywhere
        found_locally = name in cls._functions
        found_in_actor = False

        # Remove from local registry if it exists
        if found_locally:
            del cls._functions[name]

        # Try to remove from Ray actor if Ray is available
        if ray.is_initialized():
            actor = cls._get_or_create_actor()
            if actor is not None:
                # Check if it exists in actor first
                available_in_actor = ray.get(actor.list_available.remote())
                if name in available_in_actor:
                    found_in_actor = True
                    ray.get(actor.unregister.remote(name))

        # Only raise error if the function wasn't found anywhere
        if not found_locally and not found_in_actor:
            raise ValueError(f"{cls._function_type} '{name}' not registered")

    @classmethod
    def reset(cls):
        """Reset the registry (useful for testing purposes).

        If Ray is initialized and this process holds an actor handle, the named actor
        is killed (ignored if it is already gone). Local functions and cached actor
        state are always cleared.
        """
        if ray.is_initialized() and cls._ray_actor is not None:
            try:
                actor = ray.get_actor(cls._actor_name)  # raises ValueError if the actor is gone
                ray.kill(actor)
            except ValueError:
                pass  # Actor may already be gone
        cls._functions.clear()
        cls._ray_actor = None
        cls._synced_to_actor = False

    @classmethod
    def repopulate_registry(cls):
        """Register this registry's default functions.

        The base implementation registers nothing; subclasses override it to register
        their built-in functions.
        """

    @classmethod
    def repopulate(cls):
        """Reset the registry, then re-register its default functions via ``repopulate_registry``."""
        cls.reset()
        cls.repopulate_registry()

method classmethod sync_with_actor

sync_with_actor()

Sync the local registry with the Ray actor. Raises an exception if Ray is not initialized.

Source code in skyrl/backends/skyrl_train/utils/ppo_utils.py:265-303
    @classmethod
    def sync_with_actor(cls):
        """Two-way sync between the local registry and the shared Ray actor.

        First pushes local entries to the actor (skipped if already done, see
        ``_synced_to_actor``), then pulls every name the actor holds that is missing
        locally.

        Raises:
            Exception: If Ray is not initialized.
            Exception: Re-raised if a function fetched from the actor cannot be
                deserialized.
        """
        if not ray.is_initialized():
            raise Exception("Ray is not initialized, cannot sync with actor")

        # Drop cached actor state if the named actor no longer exists.
        # NOTE(Charlie): This is mainly for unit tests, where we run multiple unit tests in the
        # same Python process, and each unit test has ray init/shutdown. This makes cls's attributes
        # outdated (e.g. the _ray_actor points to a stale actor in the previous ray session).
        try:
            _ = ray.get_actor(cls._actor_name)  # raises ValueError if no actor has this name
        except ValueError:
            cls._ray_actor = None
            cls._synced_to_actor = False

        # Push: local -> actor.
        cls._sync_local_to_actor()

        actor = cls._get_or_create_actor()
        if actor is None:
            return

        available = ray.get(actor.list_available.remote())

        # Pull: actor -> local, only for names this process does not already have.
        for name in available:
            if name not in cls._functions:
                func_serialized = ray.get(actor.get.remote(name))
                if func_serialized is not None:
                    try:
                        func = cloudpickle.loads(func_serialized)
                    except Exception as e:
                        # A function that cannot be deserialized here is a hard error, not skipped.
                        logger.error(f"Error deserializing {name} from actor: {e}")
                        raise
                    cls._functions[name] = func

method classmethod register

register(name: Union[str, StrEnum], func: Callable)

Register a function.

If ray is initialized, this function will get or create a named ray actor (RegistryActor) for the registry, and sync the registry to the actor.

If Ray is not initialized, the function is stored in the local registry only.

To make sure all locally registered functions are available to all ray processes, call sync_with_actor() after ray.init().

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| name | Union[str, StrEnum] | Name of the function to register. Can be a string or a StrEnum. | required |
| func | Callable | Function to register. | required |

Raises:

| Type | Description |
| --- | --- |
| ValueError | If the function is already registered. |
Source code in skyrl/backends/skyrl_train/utils/ppo_utils.py:305-341
    @classmethod
    def register(cls, name: Union[str, StrEnum], func: Callable):
        """Add ``func`` to the registry under ``name``.

        The function is always stored in this process's local registry. When Ray is
        initialized, it is additionally cloudpickled and pushed to the registry's named
        RegistryActor (created on demand). When Ray is not initialized, only the local
        copy exists; call sync_with_actor() after ray.init() to publish it.

        Args:
            name: Key for the function; a StrEnum member is stored by its value.
            func: The callable to store.

        Raises:
            ValueError: If ``name`` is already present in the local registry.
        """
        # StrEnum members are stored by value because StrEnum cannot be cloudpickled
        # (https://github.com/cloudpipe/cloudpickle/issues/558).
        key = name.value if isinstance(name, StrEnum) else name

        if key in cls._functions:
            raise ValueError(f"{cls._function_type} '{key}' already registered")

        cls._functions[key] = func

        if not ray.is_initialized():
            return

        actor = cls._get_or_create_actor()
        if actor is None:
            return

        payload = cloudpickle.dumps(func)
        ray.get(actor.register.remote(key, payload))

method classmethod get

get(name: str) -> Callable

Get a function by name.

If ray is initialized, this function will first sync the local registry with the RegistryActor. Then it will return the function if it is found in the registry.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| name | str | Name of the function to get. Can be a string or a StrEnum. | required |

Returns:

| Type | Description |
| --- | --- |
| Callable | The function if it is found in the registry. |
Source code in skyrl/backends/skyrl_train/utils/ppo_utils.py:343-363
    @classmethod
    def get(cls, name: str) -> Callable:
        """Look up a registered function.

        When Ray is initialized, the local registry is synced with the RegistryActor
        before the lookup, so functions registered by other processes are visible.

        Args:
            name: Registry key (a plain string or a StrEnum member).

        Returns:
            The callable registered under ``name``.

        Raises:
            ValueError: If nothing is registered under ``name``.
        """
        if ray.is_initialized():
            cls.sync_with_actor()

        registry = cls._functions
        if name in registry:
            return registry[name]

        known = list(registry.keys())
        raise ValueError(f"Unknown {cls._function_type.lower()} '{name}'. Available: {known}")

method classmethod list_available

list_available() -> List[str]

List all registered functions.

Source code in skyrl/backends/skyrl_train/utils/ppo_utils.py:365-371
    @classmethod
    def list_available(cls) -> List[str]:
        """Return the names of all registered functions.

        Syncs with the RegistryActor first when Ray is initialized.
        """
        if ray.is_initialized():
            cls.sync_with_actor()
        return [*cls._functions]

method classmethod unregister

unregister(name: Union[str, StrEnum])

Unregister a function. Useful for testing.

Source code in skyrl/backends/skyrl_train/utils/ppo_utils.py:373-404
    @classmethod
    def unregister(cls, name: Union[str, StrEnum]):
        """Remove a function from the local registry and, if Ray is up, from the actor.

        Intended mainly for tests.

        Raises:
            ValueError: If ``name`` was present neither locally nor in the actor.
        """
        key = name.value if isinstance(name, StrEnum) else name

        # Pull in anything the shared actor knows about that this process does not.
        if ray.is_initialized():
            cls.sync_with_actor()

        removed_locally = key in cls._functions
        if removed_locally:
            del cls._functions[key]

        removed_remotely = False
        if ray.is_initialized():
            actor = cls._get_or_create_actor()
            if actor is not None and key in ray.get(actor.list_available.remote()):
                removed_remotely = True
                ray.get(actor.unregister.remote(key))

        if not (removed_locally or removed_remotely):
            raise ValueError(f"{cls._function_type} '{key}' not registered")

method classmethod reset

reset()

Resets the registry (useful for testing purposes).

Source code in skyrl/backends/skyrl_train/utils/ppo_utils.py:406-417
    @classmethod
    def reset(cls):
        """Clear the registry; intended for tests.

        If Ray is initialized and this process holds an actor handle, the named actor
        is killed (a missing actor is ignored). Local functions and cached actor
        state are always cleared afterwards.
        """
        should_kill = ray.is_initialized() and cls._ray_actor is not None
        if should_kill:
            try:
                # ray.get_actor raises ValueError when the named actor no longer exists.
                ray.kill(ray.get_actor(cls._actor_name))
            except ValueError:
                pass  # Actor may already be gone

        cls._functions.clear()
        cls._ray_actor = None
        cls._synced_to_actor = False

method classmethod repopulate

repopulate()

Repopulate the registry with the default functions.

Source code in skyrl/backends/skyrl_train/utils/ppo_utils.py:419-423
    @classmethod
    def repopulate(cls):
        """Reset the registry, then re-register its default functions.

        Defaults come from the class's ``repopulate_registry`` classmethod (defined by,
        e.g., AdvantageEstimatorRegistry and PolicyLossRegistry). If the class defines
        none, the registry is left empty after the reset.
        """
        cls.reset()
        # Previously this registered the `_function_type` label string as a "function",
        # which left a junk entry and no defaults.
        repopulate_registry = getattr(cls, "repopulate_registry", None)
        if repopulate_registry is not None:
            repopulate_registry()

class RegistryActor

RegistryActor()

Shared Ray actor for managing function registries across processes.

Functions:

| Name | Description |
| --- | --- |
| register | Register a serialized function. |
| get | Get a serialized function by name. |
| list_available | List all available function names. |
| unregister | Unregister a function by name. |

Attributes:

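| Name | Type | Description |
| --- | --- | --- |
| registry | dict | Maps function names to their serialized (bytes) functions. |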
Source code in skyrl/backends/skyrl_train/utils/ppo_utils.py:198-219
@ray.remote
class RegistryActor:
    """Named Ray actor holding serialized registry entries shared across processes.

    Entries map a function name to its serialized bytes; the actor stores and
    returns them as-is and never deserializes them.
    """

    def __init__(self):
        # name -> serialized function bytes
        self.registry = dict()

    def register(self, name: str, func_serialized: bytes):
        """Store (or overwrite) the serialized function under ``name``."""
        self.registry.update({name: func_serialized})

    def get(self, name: str):
        """Return the serialized function for ``name``, or ``None`` if absent."""
        registry = self.registry
        return registry[name] if name in registry else None

    def list_available(self):
        """Return the registered names as a list."""
        return [*self.registry]

    def unregister(self, name: str):
        """Remove ``name`` and return its serialized function, or ``None`` if absent."""
        registry = self.registry
        if name not in registry:
            return None
        return registry.pop(name)

attr registry

registry = {}

method register

register(name: str, func_serialized: bytes)

Register a serialized function.

Source code in skyrl/backends/skyrl_train/utils/ppo_utils.py:205-207
    def register(self, name: str, func_serialized: bytes):
        """Store (or overwrite) the serialized function under ``name``."""
        self.registry.update({name: func_serialized})

method get

get(name: str)

Get a serialized function by name.

Source code in skyrl/backends/skyrl_train/utils/ppo_utils.py:209-211
    def get(self, name: str):
        """Return the serialized function for ``name``, or ``None`` if absent."""
        registry = self.registry
        return registry[name] if name in registry else None

method list_available

list_available()

List all available function names.

Source code in skyrl/backends/skyrl_train/utils/ppo_utils.py:213-215
    def list_available(self):
        """Return the registered names as a list, in insertion order."""
        return [*self.registry]

method unregister

unregister(name: str)

Unregister a function by name.

Source code in skyrl/backends/skyrl_train/utils/ppo_utils.py:217-219
    def unregister(self, name: str):
        """Remove ``name`` and return its serialized function, or ``None`` if absent."""
        registry = self.registry
        if name not in registry:
            return None
        return registry.pop(name)

method sync_registries

sync_registries()

Sync the registries with their Ray actors once Ray is initialized.

Advantage Estimator Registry

The advantage estimator registry manages functions that compute advantages and returns.

class AdvantageEstimatorRegistry

Bases: BaseFunctionRegistry

Registry for advantage estimator functions.

This registry allows users to register custom advantage estimators without modifying the skyrl_train package. Custom estimators can be registered by calling AdvantageEstimatorRegistry.register() directly or by using the @register_advantage_estimator decorator.

See examples/algorithms/custom_advantage_estimator for a simple example of how to register and use custom advantage estimators.

Functions:

| Name | Description |
| --- | --- |
| sync_with_actor | Sync the local registry with the Ray actor (requires Ray to be initialized). |
| register | Register a function. |
| get | Get a function by name. |
| list_available | List all registered functions. |
| unregister | Unregister a function. Useful for testing. |
| reset | Reset the registry (useful for testing). |
| repopulate | Repopulate the registry with the default functions. |
| repopulate_registry | Register any default advantage estimators that are not already registered. |
Source code in skyrl/backends/skyrl_train/utils/ppo_utils.py:433-461
class AdvantageEstimatorRegistry(BaseFunctionRegistry):
    """
    Registry of advantage estimator functions.

    Users can plug in their own advantage estimators without editing the
    skyrl_train package, either by calling AdvantageEstimatorRegistry.register()
    directly or by decorating a function with @register_advantage_estimator.

    A minimal end-to-end example lives in examples/algorithms/custom_advantage_estimator.
    """

    _actor_name = "advantage_estimator_registry"
    _function_type = "advantage estimator"

    @classmethod
    def repopulate_registry(cls):
        """Register each built-in advantage estimator whose name is not already registered."""
        already_registered = set(cls.list_available())
        builtin_estimators = (
            (AdvantageEstimator.GRPO, compute_grpo_outcome_advantage),
            (AdvantageEstimator.GAE, compute_gae_advantage_return),
            (AdvantageEstimator.RLOO, compute_rloo_outcome_advantage),
            (AdvantageEstimator.REINFORCE_PP, compute_reinforce_plus_plus_outcome_advantage),
        )
        for estimator, estimator_fn in builtin_estimators:
            # Registry keys are the enum's string values.
            if estimator.value in already_registered:
                continue
            cls.register(estimator, estimator_fn)

method classmethod sync_with_actor

sync_with_actor()

Sync the local registry with the Ray actor. Raises an exception if Ray is not initialized.

method classmethod register

register(name: Union[str, StrEnum], func: Callable)

Register a function.

If ray is initialized, this function will get or create a named ray actor (RegistryActor) for the registry, and sync the registry to the actor.

If Ray is not initialized, the function is stored in the local registry only.

To make sure all locally registered functions are available to all ray processes, call sync_with_actor() after ray.init().

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| name | Union[str, StrEnum] | Name of the function to register. Can be a string or a StrEnum. | required |
| func | Callable | Function to register. | required |

Raises:

| Type | Description |
| --- | --- |
| ValueError | If the function is already registered. |

method classmethod get

get(name: str) -> Callable

Get a function by name.

If ray is initialized, this function will first sync the local registry with the RegistryActor. Then it will return the function if it is found in the registry.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| name | str | Name of the function to get. Can be a string or a StrEnum. | required |

Returns:

| Type | Description |
| --- | --- |
| Callable | The function if it is found in the registry. |

method classmethod list_available

list_available() -> List[str]

List all registered functions.

method classmethod unregister

unregister(name: Union[str, StrEnum])

Unregister a function. Useful for testing.

method classmethod reset

reset()

Resets the registry (useful for testing purposes).

method classmethod repopulate

repopulate()

Repopulate the registry with the default functions.

method classmethod repopulate_registry

repopulate_registry()
Source code in skyrl/backends/skyrl_train/utils/ppo_utils.py:449-461
    @classmethod
    def repopulate_registry(cls):
        """Register each built-in advantage estimator whose name is not already registered."""
        already_registered = set(cls.list_available())
        builtin_estimators = (
            (AdvantageEstimator.GRPO, compute_grpo_outcome_advantage),
            (AdvantageEstimator.GAE, compute_gae_advantage_return),
            (AdvantageEstimator.RLOO, compute_rloo_outcome_advantage),
            (AdvantageEstimator.REINFORCE_PP, compute_reinforce_plus_plus_outcome_advantage),
        )
        for estimator, estimator_fn in builtin_estimators:
            # Registry keys are the enum's string values.
            if estimator.value in already_registered:
                continue
            cls.register(estimator, estimator_fn)

class AdvantageEstimator

Bases: StrEnum

Attributes:

| Name | Value |
| --- | --- |
| GAE | "gae" |
| GRPO | "grpo" |
| RLOO | "rloo" |
| REINFORCE_PP | "reinforce++" |
Source code in skyrl/backends/skyrl_train/utils/ppo_utils.py:426-430
class AdvantageEstimator(StrEnum):
    """Names of the built-in advantage estimators.

    The string values are the keys under which AdvantageEstimatorRegistry.repopulate_registry
    registers the default implementations.
    """

    GAE = "gae"  # default: compute_gae_advantage_return
    GRPO = "grpo"  # default: compute_grpo_outcome_advantage
    RLOO = "rloo"  # default: compute_rloo_outcome_advantage
    REINFORCE_PP = "reinforce++"  # default: compute_reinforce_plus_plus_outcome_advantage

attr GAE

GAE = 'gae'

attr GRPO

GRPO = 'grpo'

attr RLOO

RLOO = 'rloo'

attr REINFORCE_PP

REINFORCE_PP = 'reinforce++'

method register_advantage_estimator

register_advantage_estimator(name: Union[str, AdvantageEstimator])

Decorator to register an advantage estimator function.

Policy Loss Registry

The policy loss registry manages functions that compute policy losses for PPO.

class PolicyLossRegistry

Bases: BaseFunctionRegistry

Registry for policy loss functions.

This registry allows users to register custom policy loss functions without modifying the skyrl_train package. Custom functions can be registered by calling PolicyLossRegistry.register() directly or by using the @register_policy_loss decorator.

See examples/algorithms/custom_policy_loss for a simple example of how to register and use custom policy loss functions.

Functions:

| Name | Description |
| --- | --- |
| sync_with_actor | Sync the local registry with the Ray actor (requires Ray to be initialized). |
| register | Register a function. |
| get | Get a function by name. |
| list_available | List all registered functions. |
| unregister | Unregister a function. Useful for testing. |
| reset | Reset the registry (useful for testing). |
| repopulate | Repopulate the registry with the default functions. |
| repopulate_registry | Repopulate the registry with default policy loss functions. |
Source code in skyrl/backends/skyrl_train/utils/ppo_utils.py:476-509
class PolicyLossRegistry(BaseFunctionRegistry):
    """
    Registry of policy loss functions.

    Users can plug in their own policy losses without editing the skyrl_train
    package, either by calling PolicyLossRegistry.register() directly or by
    decorating a function with @register_policy_loss.

    A minimal end-to-end example lives in examples/algorithms/custom_policy_loss.
    """

    _actor_name = "policy_loss_registry"
    _function_type = "policy loss"

    @classmethod
    def repopulate_registry(cls):
        """Repopulate the registry with default policy loss functions.

        Only names that are not already registered are added.
        """
        already_registered = set(cls.list_available())
        # NOTE(review): PolicyLossType.CISPO has no default entry here — confirm whether it
        # is registered elsewhere (e.g. via @register_policy_loss).
        builtin_losses = (
            (PolicyLossType.REGULAR, ppo_policy_loss),
            (PolicyLossType.DUAL_CLIP, ppo_policy_loss),  # same function as REGULAR
            (PolicyLossType.GSPO, gspo_policy_loss),
            (PolicyLossType.CLIP_COV, compute_policy_loss_clip_cov),
            (PolicyLossType.KL_COV, compute_policy_loss_kl_cov),
            (PolicyLossType.SAPO, sapo_policy_loss),
            (PolicyLossType.CROSS_ENTROPY, cross_entropy_loss),
            (PolicyLossType.IMPORTANCE_SAMPLING, importance_sampling_loss),
        )
        for loss_type, loss_fn in builtin_losses:
            # Registry keys are the enum's string values.
            if loss_type.value in already_registered:
                continue
            cls.register(loss_type, loss_fn)

method classmethod sync_with_actor

sync_with_actor()

Sync the local registry with the Ray actor. Raises an exception if Ray is not initialized.

method classmethod register

register(name: Union[str, StrEnum], func: Callable)

Register a function.

If ray is initialized, this function will get or create a named ray actor (RegistryActor) for the registry, and sync the registry to the actor.

If Ray is not initialized, the function is stored in the local registry only.

To make sure all locally registered functions are available to all ray processes, call sync_with_actor() after ray.init().

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| name | Union[str, StrEnum] | Name of the function to register. Can be a string or a StrEnum. | required |
| func | Callable | Function to register. | required |

Raises:

| Type | Description |
| --- | --- |
| ValueError | If the function is already registered. |

method classmethod get

get(name: str) -> Callable

Get a function by name.

If ray is initialized, this function will first sync the local registry with the RegistryActor. Then it will return the function if it is found in the registry.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| name | str | Name of the function to get. Can be a string or a StrEnum. | required |

Returns:

| Type | Description |
| --- | --- |
| Callable | The function if it is found in the registry. |

method classmethod list_available

list_available() -> List[str]

List all registered functions.

method classmethod unregister

unregister(name: Union[str, StrEnum])

Unregister a function. Useful for testing.

method classmethod reset

reset()

Resets the registry (useful for testing purposes).

method classmethod repopulate

repopulate()

Repopulate the registry with the default functions.

method classmethod repopulate_registry

repopulate_registry()

Repopulate the registry with default policy loss functions.

Source code in skyrl/backends/skyrl_train/utils/ppo_utils.py:492-509
    @classmethod
    def repopulate_registry(cls):
        """Repopulate the registry with default policy loss functions.

        Only names that are not already registered are added.
        """
        already_registered = set(cls.list_available())
        # NOTE(review): PolicyLossType.CISPO has no default entry here — confirm whether it
        # is registered elsewhere (e.g. via @register_policy_loss).
        builtin_losses = (
            (PolicyLossType.REGULAR, ppo_policy_loss),
            (PolicyLossType.DUAL_CLIP, ppo_policy_loss),  # same function as REGULAR
            (PolicyLossType.GSPO, gspo_policy_loss),
            (PolicyLossType.CLIP_COV, compute_policy_loss_clip_cov),
            (PolicyLossType.KL_COV, compute_policy_loss_kl_cov),
            (PolicyLossType.SAPO, sapo_policy_loss),
            (PolicyLossType.CROSS_ENTROPY, cross_entropy_loss),
            (PolicyLossType.IMPORTANCE_SAMPLING, importance_sampling_loss),
        )
        for loss_type, loss_fn in builtin_losses:
            # Registry keys are the enum's string values.
            if loss_type.value in already_registered:
                continue
            cls.register(loss_type, loss_fn)

class PolicyLossType

Bases: StrEnum

Attributes:

Source code in skyrl/backends/skyrl_train/utils/ppo_utils.py:464-473
class PolicyLossType(StrEnum):
    """Names of the built-in policy loss types.

    The string values are the keys under which PolicyLossRegistry.repopulate_registry
    registers the default implementations.
    """

    REGULAR = "regular"  # default: ppo_policy_loss
    DUAL_CLIP = "dual_clip"  # default: ppo_policy_loss (same function as REGULAR)
    GSPO = "gspo"  # default: gspo_policy_loss
    # NOTE(review): PolicyLossRegistry.repopulate_registry has no default for CISPO — confirm where it is registered.
    CISPO = "cispo"
    CLIP_COV = "clip_cov"  # default: compute_policy_loss_clip_cov
    KL_COV = "kl_cov"  # default: compute_policy_loss_kl_cov
    SAPO = "sapo"  # default: sapo_policy_loss
    CROSS_ENTROPY = "cross_entropy"  # default: cross_entropy_loss
    IMPORTANCE_SAMPLING = "importance_sampling"  # default: importance_sampling_loss

attr REGULAR

REGULAR = 'regular'

attr DUAL_CLIP

DUAL_CLIP = 'dual_clip'

attr GSPO

GSPO = 'gspo'

attr CISPO

CISPO = 'cispo'

attr CLIP_COV

CLIP_COV = 'clip_cov'

attr KL_COV

KL_COV = 'kl_cov'

attr SAPO

SAPO = 'sapo'

attr CROSS_ENTROPY

CROSS_ENTROPY = 'cross_entropy'

attr IMPORTANCE_SAMPLING

IMPORTANCE_SAMPLING = 'importance_sampling'

method register_policy_loss

register_policy_loss(name: Union[str, PolicyLossType])

Decorator to register a policy loss function.

On this page