Google Gen AI 5-Day Intensive: Day Four – Part 1 (4/5)

Codelab #1 – Tune A Gemini Model

This is the first assigned codelab on day four of the intensive. Download it here from GitHub to run locally, or run it in this Kaggle notebook.
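
If you run it locally rather than on Kaggle, you'll need the SDK installed and an API key in your environment first. A minimal setup sketch (the package list is inferred from the imports below, so treat it as an assumption rather than the codelab's official instructions):

# Install dependencies first (names inferred from the imports below):
#   pip install -U google-genai google-api-core pandas scikit-learn tqdm rich
import os
assert "GOOGLE_API_KEY" in os.environ, "Set GOOGLE_API_KEY before running this lab."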

"""Tune Gemini Model for Custom Function

Google Gen AI 5-Day Intensive Course
Host: Kaggle

Day: 4

Codelab: https://www.kaggle.com/code/markishere/day-4-fine-tuning-a-custom-model
"""

import datetime
import email
import os
import re
import time
import warnings
from collections.abc import Iterable

import pandas as pd
import tqdm
from google import genai
from google.api_core import retry
from google.genai import types
from sklearn.datasets import fetch_20newsgroups
from tqdm.rich import tqdm as tqdmr

client = genai.Client(api_key=os.environ["GOOGLE_API_KEY"])

for model in client.models.list():
    if "createTunedModel" in model.supported_actions:
        print(model.name)
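
# Optional sanity check: confirm the base model tuned later in this lab is
# offered for tuning. This is a sketch, not part of the codelab; it assumes
# `supported_actions` may be unset on some models.
if not any(
    m.name == 'models/gemini-1.5-flash-001-tuning'
    and 'createTunedModel' in (m.supported_actions or [])
    for m in client.models.list()
):
    print('Warning: gemini-1.5-flash-001-tuning is not tunable with this key.')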
        
newsgroups_train = fetch_20newsgroups(subset='train')
newsgroups_test = fetch_20newsgroups(subset='test')

# View the list of class names for the dataset
print(newsgroups_train.target_names)
print(newsgroups_train.data[0])

def preprocess_newsgroup_row(data):
    # Extract only the subject and body.
    msg = email.message_from_string(data)
    text = f'{msg["Subject"]}\n\n{msg.get_payload()}'
    # Strip any remaining email addresses
    text = re.sub(r"[\w\.-]+@[\w\.-]+", "", text)
    # Truncate the text to fit within the input limits
    text = text[:40000]
    
    return text
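
# Quick illustration of the cleaning step on a made-up post (hypothetical data,
# not drawn from the dataset):
sample_post = (
    "Subject: Test ride\nFrom: rider@example.com\n\n"
    "Contact me at rider@example.com!"
)
print(preprocess_newsgroup_row(sample_post))
# Prints "Test ride", a blank line, then "Contact me at !" - the subject and
# body are kept, and email addresses are stripped.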
    
def preprocess_newsgroup_data(newsgroup_dataset):
    # Put the points into a DataFrame
    df = pd.DataFrame(
        {
            'Text': newsgroup_dataset.data,
            'Label': newsgroup_dataset.target
        }
    )
    #  Clean up the text
    df['Text'] = df['Text'].apply(preprocess_newsgroup_row)
    # Match label to target name index
    df['Class Name'] = df['Label'].map(lambda l: newsgroup_dataset.target_names[l])
    
    return df

# Apply preprocessing to training and test datasets
df_train = preprocess_newsgroup_data(newsgroups_train)
df_test = preprocess_newsgroup_data(newsgroups_test)

print(df_train.head())

def sample_data(df, num_samples, classes_to_keep):
    # Sample rows, selecting num_samples of each label.
    df = (
        df.groupby('Label')[df.columns]
        .apply(lambda x: x.sample(num_samples))
        .reset_index(drop=True)
    )
    
    df = df[df['Class Name'].str.contains(classes_to_keep)]
    df['Class Name'] = df['Class Name'].astype('category')
    
    return df

TRAIN_NUM_SAMPLES = 50
TEST_NUM_SAMPLES = 10
# Keep rec.* and sci.*
CLASSES_TO_KEEP = '^rec|^sci'

df_train = sample_data(df_train, TRAIN_NUM_SAMPLES, CLASSES_TO_KEEP)
df_test = sample_data(df_test, TEST_NUM_SAMPLES, CLASSES_TO_KEEP)
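
# Sanity check: the sample should now hold only the rec.* and sci.* groups,
# with TRAIN_NUM_SAMPLES rows per class.
print(df_train['Class Name'].value_counts())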

# Evaluate baseline performance
sample_idx = 0
sample_row = preprocess_newsgroup_row(newsgroups_test.data[sample_idx])
sample_label = newsgroups_test.target_names[newsgroups_test.target[sample_idx]]

print(sample_row)
print('---')
print('Label:', sample_label)

response = client.models.generate_content(
    model='gemini-1.5-flash-001',
    contents=sample_row
)
print(response.text)


# Ask the model directly in a zero-shot prompt.

prompt = "From what newsgroup does the following message originate?"
baseline_response = client.models.generate_content(
    model="gemini-1.5-flash-001",
    contents=[prompt, sample_row])
print(baseline_response.text)


# You can use a system instruction to do more direct prompting, and get a
# more succinct answer.

system_instruct = """
You are a classification service. You will be passed input that represents
a newsgroup post and you must respond with the newsgroup from which the post
originates.
"""

# Define a helper to retry when per-minute quota is reached.
is_retriable = lambda e: (isinstance(e, genai.errors.APIError) and e.code in {429, 503})
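
# The retry policy's backoff can also be tuned; a sketch with arbitrary values
# (the keyword arguments are standard google.api_core.retry.Retry parameters).
# A Retry instance is itself a decorator, so it can replace the plain
# retry.Retry(predicate=is_retriable) used below.
patient_retry = retry.Retry(
    predicate=is_retriable, initial=2.0, maximum=64.0, multiplier=2.0, timeout=600.0
)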

# If you want to evaluate your own technique, replace the body of this function
# with your model, prompt, and other code, and return the predicted answer.
@retry.Retry(predicate=is_retriable)
def predict_label(post: str) -> str:
    response = client.models.generate_content(
        model="gemini-1.5-flash-001",
        config=types.GenerateContentConfig(
            system_instruction=system_instruct),
        contents=post)

    rc = response.candidates[0]

    # Any errors, filters, recitation, etc. are marked as a general error.
    if rc.finish_reason.name != "STOP":
        return "(error)"
    else:
        # Clean up the response.
        return response.text.strip()


prediction = predict_label(sample_row)

print(prediction)
print()
print("Correct!" if prediction == sample_label else "Incorrect.")


# Enable tqdm features on Pandas.
tqdmr.pandas()

# But suppress the experimental warning
warnings.filterwarnings("ignore", category=tqdm.TqdmExperimentalWarning)


# Further sample the test data to be mindful of the free-tier quota.
df_baseline_eval = sample_data(df_test, 2, '.*')

# Make predictions using the sampled data.
df_baseline_eval['Prediction'] = df_baseline_eval['Text'].progress_apply(predict_label)

# And calculate the accuracy.
accuracy = (df_baseline_eval["Class Name"] == df_baseline_eval["Prediction"]).sum() / len(df_baseline_eval)
print(f"Accuracy: {accuracy:.2%}")


# Tune a custom model
# Convert the data frame into a dataset suitable for tuning.
input_data = {
    'examples': df_train[['Text', 'Class Name']]
        .rename(columns={'Text': 'textInput', 'Class Name': 'output'})
        .to_dict(orient='records')
}
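
# Inspect one record to confirm the payload shape the tuning service expects
# ('textInput' paired with the target 'output'):
print(input_data['examples'][0])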

# If you are re-running this lab, add your model_id here.
model_id = None

# Or try to find a recent tuning job.
if not model_id:
    queued_model = None
    # Newest models first.
    for m in reversed(client.tunings.list()):
        # Only look at newsgroup classification models.
        if m.name.startswith('tunedModels/newsgroup-classification-model'):
            # If there is a completed model, use the first (newest) one.
            if m.state.name == 'JOB_STATE_SUCCEEDED':
                model_id = m.name
                print('Found existing tuned model to reuse.')
                break

            elif m.state.name == 'JOB_STATE_RUNNING' and not queued_model:
                # If there's a model still queued, remember the most recent one.
                queued_model = m.name
    else:
        # The loop finished without finding a completed model, so fall back
        # to the queued one, if any.
        if queued_model:
            model_id = queued_model
            print('Found queued model, still waiting.')


# Upload the training data and queue the tuning job.
if not model_id:
    tuning_op = client.tunings.tune(
        base_model="models/gemini-1.5-flash-001-tuning",
        training_dataset=input_data,
        config=types.CreateTuningJobConfig(
            tuned_model_display_name="Newsgroup classification model",
            batch_size=16,
            epoch_count=2,
        ),
    )

    print(tuning_op.state)
    model_id = tuning_op.name

print(model_id)


MAX_WAIT = datetime.timedelta(minutes=10)

while not (tuned_model := client.tunings.get(name=model_id)).has_ended:

    print(tuned_model.state)
    time.sleep(60)

    # Don't wait too long. Use a public model if this is going to take a while.
    if datetime.datetime.now(datetime.timezone.utc) - tuned_model.create_time > MAX_WAIT:
        print("Taking a shortcut, using a previously prepared model.")
        model_id = "tunedModels/newsgroup-classification-model-ltenbi1b"
        tuned_model = client.tunings.get(name=model_id)
        break


print(f"Done! The model state is: {tuned_model.state.name}")

if not tuned_model.has_succeeded and tuned_model.error:
    print("Error:", tuned_model.error)
    

#  Use the new model
new_text = """
First-timer looking to get out of here.

Hi, I'm writing about my interest in travelling to the outer limits!

What kind of craft can I buy? What is easiest to access from this 3rd rock?

Let me know how to do that please.
"""

response = client.models.generate_content(
    model=model_id, contents=new_text)

print(response.text)


@retry.Retry(predicate=is_retriable)
def classify_text(text: str) -> str:
    """Classify the provided text into a known newsgroup."""
    response = client.models.generate_content(
        model=model_id, 
        contents=text)
    rc = response.candidates[0]

    # Any errors, filters, recitation, etc. are marked as a general error.
    if rc.finish_reason.name != "STOP":
        return "(error)"
    else:
        return rc.content.parts[0].text


# The sampling here is just to minimise your quota usage. If you can, you should
# evaluate the whole test set with `df_model_eval = df_test.copy()`.
df_model_eval = sample_data(df_test, 4, '.*')

df_model_eval["Prediction"] = df_model_eval["Text"].progress_apply(classify_text)

accuracy = (df_model_eval["Class Name"] == df_model_eval["Prediction"]).sum() / len(df_model_eval)
print(f"Accuracy: {accuracy:.2%}")


# Compare token usage
# Calculate the *input* cost of the baseline model with system instructions.
sysint_tokens = client.models.count_tokens(
    model='gemini-1.5-flash-001', contents=[system_instruct, sample_row]
).total_tokens
print(f'System instructed baseline model: {sysint_tokens} (input)')

# Calculate the input cost of the tuned model.
tuned_tokens = client.models.count_tokens(model=tuned_model.base_model, contents=sample_row).total_tokens
print(f'Tuned model: {tuned_tokens} (input)')

savings = (sysint_tokens - tuned_tokens) / tuned_tokens
print(f'Token savings: {savings:.2%}')  # Note that this is only n=1.


# Tweak output token quantity
baseline_token_output = baseline_response.usage_metadata.candidates_token_count
print('Baseline (verbose) output tokens:', baseline_token_output)

tuned_model_output = client.models.generate_content(
    model=model_id, contents=sample_row)
tuned_tokens_output = tuned_model_output.usage_metadata.candidates_token_count
print('Tuned output tokens:', tuned_tokens_output)
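
# Relative output-token reduction for this single example (n=1, so treat it
# as indicative only):
print(f'Output tokens saved: {1 - tuned_tokens_output / baseline_token_output:.2%}')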
