Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MAINT: Refactor Skeleton Key to be a subclass of PromptSendingOrchest… #650

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 65 additions & 52 deletions pyrit/orchestrator/skeleton_key_orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,12 @@
PromptConverterConfiguration,
)
from pyrit.prompt_target import PromptTarget
from pyrit.orchestrator.single_turn.prompt_sending_orchestrator import PromptSendingOrchestrator
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
from pyrit.orchestrator.single_turn.prompt_sending_orchestrator import PromptSendingOrchestrator
from pyrit.orchestrator.single_turn import PromptSendingOrchestrator


logger = logging.getLogger(__name__)


class SkeletonKeyOrchestrator(Orchestrator):
class SkeletonKeyOrchestrator(PromptSendingOrchestrator):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Recommend moving this file to pyrit/orchestrator/single_turn

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(and, as part of that move, updating `__init__.py`)

"""
Creates an orchestrator that executes a skeleton key jailbreak.

Expand Down Expand Up @@ -60,9 +61,12 @@ def __init__(
ensure proper rate limit management.
verbose (bool, Optional): If set to True, verbose output will be enabled. Defaults to False.
"""
super().__init__(prompt_converters=prompt_converters, verbose=verbose)

self._prompt_normalizer = PromptNormalizer()
super().__init__(
objective_target=prompt_target,
prompt_converters=prompt_converters,
batch_size=batch_size,
verbose=verbose
)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pre-commit is failing. You can run it locally with `pre-commit run --all-files`, and you can see the current issues in the CI build: https://github.com/Azure/PyRIT/actions/runs/12876675701/job/35900136914?pr=650

self._skeleton_key_prompt = (
skeleton_key_prompt
Expand All @@ -74,80 +78,89 @@ def __init__(
.value
)

self._prompt_target = prompt_target

self._batch_size = batch_size

async def send_skeleton_key_with_prompt_async(
async def send_skeleton_key_with_prompts_async(
self,
*,
prompt: str,
) -> PromptRequestResponse:
prompt_list: list[str],
) -> list[PromptRequestResponse]:
"""
Sends a skeleton key, followed by the attack prompt to the target.

Args
Sends a skeleton key and prompt to the target for each prompt in a list of prompts.

prompt (str): The prompt to be sent.
prompt_type (PromptDataType, Optional): The type of the prompt (e.g., "text"). Defaults to "text".
Args:
prompt_list (list[str]): The list of prompts to be sent.

Returns:
PromptRequestResponse: The response from the prompt target.
list[PromptRequestResponse]: The responses from the prompt target.
"""
if hasattr(self._prompt_target, 'rpm') and self._prompt_target.rpm and self._batch_size != 1:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This check should be removed - enforcing the rate limit / batch size constraint is the responsibility of the target.

raise ValueError(
"When using a prompt target with max_requests_per_minute, batch_size must be set to 1"
)

# Create a single conversation ID for the entire sequence
conversation_id = str(uuid4())
metadata = {"conversation_id": conversation_id}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove conversation_id from the metadata. PromptRequestPieces already have conversation_id as an attribute, so it can be set there instead.


skeleton_key_prompt = SeedPromptGroup(prompts=[SeedPrompt(value=self._skeleton_key_prompt, data_type="text")])

converter_configuration = PromptConverterConfiguration(converters=self._prompt_converters)

await self._prompt_normalizer.send_prompt_async(
seed_prompt_group=skeleton_key_prompt,
conversation_id=conversation_id,
request_converter_configurations=[converter_configuration],
target=self._prompt_target,
labels=self._global_memory_labels,
orchestrator_identifier=self.get_identifier(),
# First, send all skeleton keys
skeleton_keys = [self._skeleton_key_prompt] * len(prompt_list)
await self.send_prompts_async(
prompt_list=skeleton_keys,
metadata=metadata
)

objective_prompt = SeedPromptGroup(prompts=[SeedPrompt(value=prompt, data_type="text")])

return await self._prompt_normalizer.send_prompt_async(
seed_prompt_group=objective_prompt,
conversation_id=conversation_id,
request_converter_configurations=[converter_configuration],
target=self._prompt_target,
labels=self._global_memory_labels,
orchestrator_identifier=self.get_identifier(),
# Then send all attack prompts with the same conversation ID
attack_responses = await self.send_prompts_async(
prompt_list=prompt_list,
metadata=metadata
)

return attack_responses

async def send_skeleton_key_with_prompts_async(
async def send_skeleton_key_with_prompt_async(
self,
*,
prompt_list: list[str],
) -> list[PromptRequestResponse]:
prompt: str,
) -> PromptRequestResponse:
"""
Sends a skeleton key and prompt to the target for each prompt in a list of prompts.
Sends a skeleton key, followed by the attack prompt to the target.

Args:
prompt_list (list[str]): The list of prompts to be sent.
prompt_type (PromptDataType, Optional): The type of the prompts (e.g., "text"). Defaults to "text".
prompt (str): The prompt to be sent.

Returns:
list[PromptRequestResponse]: The responses from the prompt target.
PromptRequestResponse: The response from the prompt target.
"""
Copy link
Contributor

@rlundeen2 rlundeen2 Jan 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of re-implementing the previous flow, I would make use of the functionality PromptSendingOrchestrator already provides.

  1. Use set_prepended_conversation to set the skeleton key. This can be done in `__init__`.
  2. Remove the send_skeleton_key_with_prompt_async function. Any prompt that is sent will then have the skeleton key prepended.

# Create a single conversation ID
conversation_id = str(uuid4())
metadata = {"conversation_id": conversation_id}

# Create normalizer requests for both prompts
skeleton_request = self._create_normalizer_request(
prompt_text=self._skeleton_key_prompt,
prompt_type="text",
converters=self._prompt_converters,
metadata=metadata,
conversation_id=conversation_id,
)

return await batch_task_async(
task_func=self.send_skeleton_key_with_prompt_async,
task_arguments=["prompt"],
prompt_target=self._prompt_target,
batch_size=self._batch_size,
items_to_batch=[prompt_list],
attack_request = self._create_normalizer_request(
prompt_text=prompt,
prompt_type="text",
converters=self._prompt_converters,
metadata=metadata,
conversation_id=conversation_id,
)

# Send both requests in a single batch
responses = await self.send_normalizer_requests_async(
prompt_request_list=[skeleton_request, attack_request]
)

# Return the attack prompt response (second response)
return responses[1]

def print_conversation(self) -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should be able to remove print_conversation and use the parent class's implementation instead.

"""Prints all the conversations that have occured with the prompt target."""

target_messages = self.get_memory()

if not target_messages or len(target_messages) == 0:
Expand All @@ -158,4 +171,4 @@ def print_conversation(self) -> None:
if message.role == "user":
print(f"{Style.BRIGHT}{Fore.RED}{message.role}: {message.converted_value}\n")
else:
print(f"{Style.BRIGHT}{Fore.GREEN}{message.role}: {message.converted_value}\n")
print(f"{Style.BRIGHT}{Fore.GREEN}{message.role}: {message.converted_value}\n")
Loading