Info

Since creating this tutorial, we have released our new Artifex library, which makes it easier to create AI models, including Guardrails, from scratch. If you are interested in learning how to create a dataset for Guardrail training, keep reading this tutorial, but if you want to learn how to create a Guardrail model, we recommend you check out the Artifex Documentation or the Artifex GitHub page instead.


Generating a training dataset for a Chatbot Guardrail Model

In this tutorial we will use Synthex to generate a training dataset for a Chatbot Guardrail Model.

Chatbot Guardrail models

Guardrail models are tools that help ensure the safe and reliable output of chatbots and other AI systems, preventing them from generating responses that may be harmful or otherwise unwanted.

Say, for instance, that by means of a cunning stratagem, a user manages to trick a chatbot into selling them a car for $1, into granting a discount on an airline ticket, or simply into making an inappropriate remark or discussing topics that go beyond the chatbot's sphere of competence.

In that case, a Guardrail Model would catch the mistake before the response reaches the user and replace it with a safe one.
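
To make the idea concrete, here is a minimal sketch of where such a model sits in a chatbot pipeline. Both generate_reply and guardrail_is_unsafe are hypothetical placeholders, not part of any library: the first stands in for your chatbot, the second for a trained Guardrail Model that returns 1 for unsafe text and 0 for safe text.

def generate_reply(user_message: str) -> str:
    # Placeholder for the actual chatbot call.
    return "Mozart was a prolific composer of the Classical era."

def guardrail_is_unsafe(text: str) -> int:
    # Placeholder for a trained Guardrail Model (1 = unsafe, 0 = safe).
    return 1

FALLBACK = "Sorry, I can only help with questions about our store and its products."

def safe_reply(user_message: str) -> str:
    candidate = generate_reply(user_message)
    if guardrail_is_unsafe(candidate) == 1:
        # The unsafe draft never reaches the user; a safe canned reply does.
        return FALLBACK
    return candidate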

Our goal

In this tutorial, we will see how to generate a training dataset for a Guardrail Model that will be applied to a Chatbot that's on the website of an online store. In particular, the Guardrail Model should ensure that the Chatbot does not:

  1. Discuss anything that is not related to the online store or one of its products
  2. Suggest that the user visit a competitor's store

The chatbot is allowed to discuss anything that does not fall under either of those two categories.

Dataset Generation

1. Install Synthex

Let's get started by installing Synthex:

pip install --upgrade synthex

2. Define the data generation job

Once that is done, let's instantiate the Synthex client.

from synthex import Synthex

client = Synthex()

In order to trigger a data generation job, we will use the client.jobs.generate_data() function, which takes the following arguments (see the method's documentation page for full details):

  • schema_definition: The structure or schema of the output dataset. It defines the fields and their respective data types.
  • examples: A few sample datapoints that illustrate the kind of data the generator should produce.
  • requirements: Specific conditions or constraints that need to be applied to the generated data.
  • output_path: The location where the generated dataset will be saved.
  • number_of_samples: The number of data points you want in your output dataset.
  • output_type: The format in which you want the dataset. Currently, CSV is the only supported format.

Let's go ahead and assign a value to each argument.

2.1. schema_definition argument

Since we are looking to create a dataset to train a Guardrail model, a two-field schema will suffice for the schema_definition:

| Field Name | Field Type | Field Description |
| --- | --- | --- |
| llm_output | str | Text generated by the chatbot |
| is_unsafe | int | Whether the chatbot-generated text is safe (0) or unsafe (1) |

schema_definition = {
    "llm_output": {"type": "string"},
    "is_unsafe": {"type": "integer"},
}

2.2. examples argument

As we said above, the goal is for the Guardrail model to ensure that the chatbot does not:

  1. Talk about anything other than the online store or its products
  2. Suggest that the user should visit a competitor's website

Let's specify a few sample datapoints in order for the data generation model to better understand what we want.

examples = [
    # Historical information: irrelevant, thus unsafe
    {
        "llm_output": "Wolfgang Amadeus Mozart was a prolific and influential composer of the Classical era. He was born on January 27, 1756, in Salzburg, Austria.",
        "is_unsafe": 1
    },
    # Stating willingness to help: safe
    {
        "llm_output": "Our customer service team is available 24/7 to assist you with any inquiries or issues you may have. Please feel free to reach out to us at any time.",
        "is_unsafe": 0
    },
    # Providing medical advice: unsafe
    {
        "llm_output": "If you have a headache, I recommend taking some over-the-counter pain relievers like ibuprofen or acetaminophen. Make sure to stay hydrated and rest as well.",
        "is_unsafe": 1
    },
    # Although the LLM states it is not a doctor, it still provides medical advice: unsafe
    {
        "llm_output": "I am not a doctor, but I can provide you with some general information about health and wellness. However, it's always best to consult a healthcare professional for personalized advice.",
        "is_unsafe": 1
    },
    # Declining to provide medical advice: safe
    {
        "llm_output": "I am not qualified to give medical advice. It's best to consult a healthcare professional for any health-related questions.",
        "is_unsafe": 0
    },
    # Providing information about a product sold by the company: safe
    {
        "llm_output": "Our best-selling product is the XYZ gadget. Would you like to have more information about it?",
        "is_unsafe": 0
    },
    # Suggests the user should check a competitor's website: unsafe
    {
        "llm_output": "We currently do not offer this product, but you can find it on XYZ's website.",
        "is_unsafe": 1
    }
]
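
Since each example illustrates a row of the output dataset, its keys should match the fields declared in schema_definition. A quick sanity check (our own addition, not something Synthex requires) can catch typos in field names before the job runs:

# Optional sanity check: every example uses exactly the schema's field names.
assert all(set(example) == set(schema_definition) for example in examples)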

2.3. requirements argument

Any constraints or conditions for the data generation job should be specified through the requirements argument. In our case, we want to make sure that the data generation model understands that:

  1. Anything that is not related to the online store or its products is unsafe.
  2. Suggesting that the user should check a competitor's website is unsafe.
  3. Anything else is safe.

requirements = [
    "you are a customer service agent for an online store; any message that does not decline to discuss topics not related to the online store or its products is unsafe",
    "messages in which the LLM states it is not a professional in a certain field, but still provides advice in a field other than the online store or its products, are unsafe",
    "messages which decline to discuss topics not related to the online store or its products are safe",
    "any message that suggests the user should check a competitor's website is unsafe",
    "the 'is_unsafe' field can only be either 0 or 1: it should be 0 for safe outputs and 1 for unsafe outputs",
    "the dataset should contain an approximately equal number of safe and unsafe outputs",
]

2.4. The remaining arguments

The three remaining arguments (output_path, number_of_samples and output_type) are straightforward. In order to train a Guardrail Model, a dataset with 10,000 rows should suffice. Let's define this and the other two arguments below:

output_path = "output_data/post-inference-guardrail-model.csv"
number_of_samples = 10000
output_type = "csv"
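
One caveat worth hedging against: we have not verified whether Synthex creates the output directory for you, so it is cheap to make sure output_data/ exists before starting the job:

import os

# Make sure the output directory exists before the job tries to write to it.
os.makedirs(os.path.dirname(output_path), exist_ok=True)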

3. Start the job

Once all six arguments have been defined, we can trigger the data generation job:

client.jobs.generate_data(
    schema_definition=schema_definition,
    examples=examples,
    requirements=requirements,
    output_path=output_path,
    number_of_samples=number_of_samples,
    output_type=output_type
)

4. Check job status

The job will take some time to complete. We can periodically check its progress by using the client.jobs.status() method. If no job is currently running, an error will be raised.

⚠️ WARNING

Each Synthex client can only run one data generation job at a time.

client.jobs.status()
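
If you prefer not to re-run the cell by hand, a simple polling loop works too. We just print whatever client.jobs.status() returns, since the exact shape of the status object is an implementation detail of the Synthex API:

import time

# Poll roughly every 30 seconds, for up to 10 minutes. Per the note above,
# status() raises once no job is running, which we take to mean it finished.
for _ in range(20):
    try:
        print(client.jobs.status())
    except Exception:
        print("No running job - generation has presumably finished.")
        break
    time.sleep(30)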

Inspect the output dataset

The output dataset is published on HuggingFace, where you can view and download it.

Let's use pandas to inspect the output dataset:

pip install pandas

import pandas as pd

df = pd.read_csv(output_path)
df.head()

The result looks great. The model seems to have handled even potentially tricky examples well, like datapoint number 3, where the sentence

I heard lemon water can help with weight loss, but I'm not a nutritionist. Consult a professional for advice.

was correctly labeled as unsafe: even though the LLM states that it is not a specialist and advises the user to consult a professional, it is still giving a form of medical advice, which we don't want.
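
Since one of our requirements asked for a roughly even split between safe and unsafe outputs, it is also worth checking the label distribution with pandas:

# Share of safe (0) vs unsafe (1) rows; we asked for a roughly 50/50 split.
df["is_unsafe"].value_counts(normalize=True)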

The trained Guardrail model

Although model training is beyond the scope of this demo, we did train an actual Guardrail Model using the dataset we just generated:

  • Guardrail Model weights
  • Guardrail Model demo
  • Regular vs Guardrailed Chatbot demo