359 changes: 359 additions & 0 deletions notebooks/train-yolo26-pose-estimation-on-custom-dataset.ipynb
@@ -0,0 +1,359 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"gpuType": "T4"
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<div align=\"center\">\n",
" <a href=\"https://roboflow.com\" target=\"_blank\">\n",
" <img\n",
" width=\"100%\"\n",
" src=\"https://media.roboflow.com/notebooks/template/bannerformats-dark.png?ik-sdk-version=javascript-1.4.3&updatedAt=1672949716562\"\n",
" >\n",
" </a>\n",
"</div>\n",
"\n",
"# Fine-Tune YOLO26 on Pose Estimation Dataset\n",
"\n",
"---\n",
"\n",
"[![Roboflow](https://raw.githubusercontent.com/roboflow-ai/notebooks/main/assets/badges/roboflow-blogpost.svg)](https://roboflow.com)\n",
"[![YouTube](https://badges.aleen42.com/src/youtube.svg)](https://youtube.com/roboflow)\n",
"[![GitHub](https://badges.aleen42.com/src/github.svg)](https://github.com/ultralytics/ultralytics)\n",
"\n",
"YOLO26 introduces **NMS-free end-to-end inference**, which removes the non-maximum suppression post-processing step, and the **MuSGD optimizer** for training. Ultralytics positions it as faster and more accurate than previous YOLO versions. This notebook shows how to fine-tune YOLO26 on a custom pose estimation dataset, using [Roboflow](https://roboflow.com) for dataset management.\n",
"\n",
"**What you'll learn:**\n",
"- How to prepare a pose estimation dataset with Roboflow\n",
"- How to fine-tune YOLO26 on a custom keypoint dataset\n",
"- How to run inference and visualize keypoints\n",
"- How to export the model for deployment\n",
"\n",
"**Before you start**, make sure you have access to a GPU. You can use a **free T4 GPU** by navigating to `Runtime` → `Change runtime type` → `T4 GPU`."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 1: Install Dependencies"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"HOME = os.getcwd()\n",
"print(\"HOME:\", HOME)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install ultralytics supervision roboflow -q\n",
"\n",
"import ultralytics\n",
"ultralytics.checks()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 2: Load Pretrained YOLO26 Pose Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from ultralytics import YOLO\n",
"\n",
"# Load YOLO26 nano pose model pretrained on COCO\n",
"model = YOLO(\"yolo26n-pose.pt\")\n",
"print(\"Model loaded successfully!\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 3: Download Dataset from Roboflow\n",
"\n",
"We'll use a publicly available pose estimation dataset from Roboflow Universe. You can replace this with your own dataset.\n",
"\n",
"**To use your own dataset:**\n",
"1. Go to [Roboflow Universe](https://universe.roboflow.com)\n",
"2. Find or upload your pose estimation dataset\n",
"3. Export in **YOLOv8 Pose** format\n",
"4. Replace the snippet below with your own export code"
]
},
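{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"from roboflow import Roboflow\n",
"\n",
"# Read the API key from the environment if it is set; otherwise replace the placeholder\n",
"api_key = os.environ.get(\"ROBOFLOW_API_KEY\", \"YOUR_ROBOFLOW_API_KEY\")\n",
"\n",
"rf = Roboflow(api_key=api_key)\n",
"# NOTE: the project must be a keypoint detection project so that the labels include keypoints\n",
"project = rf.workspace(\"roboflow-jvuqo\").project(\"football-players-detection-3zvbc\")\n",
"version = project.version(1)\n",
"dataset = version.download(\"yolov8-pose\")\n",
"\n",
"print(f\"Dataset downloaded to: {dataset.location}\")"
]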
},
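{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before training, it can help to confirm that the download really is a keypoint dataset. The cell below is a minimal sketch, not part of the original workflow: it assumes the export wrote a `data.yaml` containing the `kpt_shape` key that Ultralytics pose training reads, and that PyYAML is available (it is installed as an Ultralytics dependency). The names `data_cfg` and `kpt_shape` are illustrative."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"\n",
"import yaml\n",
"\n",
"# Sketch: load the dataset config written by the Roboflow export\n",
"data_yaml = Path(dataset.location) / \"data.yaml\"\n",
"with open(data_yaml) as f:\n",
"    data_cfg = yaml.safe_load(f)\n",
"\n",
"# Assumption: pose datasets declare kpt_shape as [num_keypoints, values_per_keypoint]\n",
"kpt_shape = data_cfg.get(\"kpt_shape\")\n",
"if kpt_shape is None:\n",
"    print(\"No kpt_shape in data.yaml: this does not look like a pose dataset\")\n",
"else:\n",
"    print(f\"Keypoints per instance: {kpt_shape[0]}, values per keypoint: {kpt_shape[1]}\")\n",
"print(f\"Classes: {data_cfg.get('names')}\")"
]
},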
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 4: Fine-Tune YOLO26 on Pose Estimation Dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from ultralytics import YOLO\n",
"\n",
"model = YOLO(\"yolo26n-pose.pt\")\n",
"\n",
"results = model.train(\n",
"    data=f\"{dataset.location}/data.yaml\",\n",
"    epochs=100,\n",
"    imgsz=640,\n",
"    batch=16,\n",
"    device=0,  # first CUDA GPU\n",
"    project=\"pose_training\",\n",
"    name=\"yolo26n_pose\",\n",
"    exist_ok=True,  # reuse the run directory instead of creating a numbered copy\n",
"    patience=20,  # stop early after 20 epochs without improvement\n",
"    save=True,\n",
"    plots=True,\n",
")\n",
"\n",
"print(\"\\nTraining complete!\")\n",
"print(f\"Run directory: {model.trainer.save_dir}\")\n",
"print(f\"Best weights: {model.trainer.best}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 5: Evaluate the Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from ultralytics import YOLO\n",
"\n",
"# Load best trained model\n",
"model = YOLO(\"pose_training/yolo26n_pose/weights/best.pt\")\n",
"\n",
"# Run validation\n",
"metrics = model.val(\n",
"    data=f\"{dataset.location}/data.yaml\",\n",
"    device=0,\n",
")\n",
"\n",
"print(f\"Pose mAP50: {metrics.pose.map50:.4f}\")\n",
"print(f\"Pose mAP50-95: {metrics.pose.map:.4f}\")\n",
"print(f\"Box mAP50: {metrics.box.map50:.4f}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 6: Run Inference and Visualize Keypoints"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import glob\n",
"import os\n",
"import random\n",
"\n",
"from IPython.display import display, Image as IPImage\n",
"from ultralytics import YOLO\n",
"\n",
"model = YOLO(\"pose_training/yolo26n_pose/weights/best.pt\")\n",
"\n",
"# Pick a random validation image\n",
"val_images = glob.glob(f\"{dataset.location}/valid/images/*.jpg\")\n",
"assert val_images, \"No .jpg images found in the validation folder\"\n",
"test_image = random.choice(val_images)\n",
"\n",
"results = model.predict(\n",
"    source=test_image,\n",
"    conf=0.35,\n",
"    save=True,\n",
"    project=\"pose_inference\",\n",
"    name=\"yolo26_pose_results\",\n",
"    exist_ok=True,\n",
")\n",
"\n",
"# Display the annotated copy of this image, not a leftover file from an earlier run\n",
"output_image = os.path.join(\"pose_inference/yolo26_pose_results\", os.path.basename(test_image))\n",
"display(IPImage(filename=output_image, width=800))\n",
"\n",
"# Print keypoint details\n",
"for r in results:\n",
"    if r.keypoints is not None:\n",
"        print(f\"Detected {len(r.keypoints)} instance(s)\")\n",
"        print(f\"Keypoints shape: {r.keypoints.xy.shape}\")"
]
},
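{
"cell_type": "markdown",
"metadata": {},
"source": [
"The keypoint tensor printed above can also be read per instance. The cell below is a minimal sketch that assumes `results` still holds the output of the previous cell. It uses the `xy` and `conf` attributes of the Ultralytics `Keypoints` object, where `conf` is `None` when the dataset stores only x and y for each keypoint."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: list the keypoints of the first result from the previous cell\n",
"r = results[0]\n",
"if r.keypoints is not None and len(r.keypoints) > 0:\n",
"    xy = r.keypoints.xy.cpu().numpy()  # shape: (instances, keypoints, 2), in pixels\n",
"    conf = r.keypoints.conf  # None if the dataset has no visibility channel\n",
"    for i, instance in enumerate(xy):\n",
"        print(f\"Instance {i}:\")\n",
"        for k, (x, y) in enumerate(instance):\n",
"            score = f\"{conf[i, k].item():.2f}\" if conf is not None else \"n/a\"\n",
"            print(f\"  keypoint {k}: x={x:.1f}, y={y:.1f}, conf={score}\")\n",
"else:\n",
"    print(\"No keypoints detected in this image\")"
]
},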
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 7: Visualize with Supervision"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import cv2\n",
"import supervision as sv\n",
"from IPython.display import display, Image as IPImage\n",
"from ultralytics import YOLO\n",
"\n",
"model = YOLO(\"pose_training/yolo26n_pose/weights/best.pt\")\n",
"\n",
"image = cv2.imread(test_image)\n",
"results = model(image, conf=0.35)[0]\n",
"\n",
"# Extract detections and keypoints\n",
"keypoints = sv.KeyPoints.from_ultralytics(results)\n",
"detections = sv.Detections.from_ultralytics(results)\n",
"\n",
"# Annotate\n",
"box_annotator = sv.BoxAnnotator()\n",
"edge_annotator = sv.EdgeAnnotator(color=sv.Color.GREEN, thickness=2)\n",
"vertex_annotator = sv.VertexAnnotator(color=sv.Color.RED, radius=4)\n",
"\n",
"annotated = box_annotator.annotate(image.copy(), detections)\n",
"annotated = edge_annotator.annotate(annotated, keypoints)\n",
"annotated = vertex_annotator.annotate(annotated, keypoints)\n",
"\n",
"# Save and display\n",
"output_path = \"pose_supervision_output.jpg\"\n",
"cv2.imwrite(output_path, annotated)\n",
"display(IPImage(filename=output_path, width=800))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 8: Export for Deployment"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from ultralytics import YOLO\n",
"\n",
"model = YOLO(\"pose_training/yolo26n_pose/weights/best.pt\")\n",
"\n",
"# Export to ONNX\n",
"model.export(format=\"onnx\", imgsz=640, simplify=True)\n",
"print(\"Exported to ONNX\")\n",
"\n",
"# Export to TFLite (mobile devices); INT8 quantization calibrates on the dataset passed via `data`\n",
"# model.export(format=\"tflite\", imgsz=640, int8=True, data=f\"{dataset.location}/data.yaml\")\n",
"\n",
"# Export to TensorRT (NVIDIA GPUs)\n",
"# model.export(format=\"engine\", imgsz=640, half=True)"
]
},
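{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick way to sanity-check the export is to load the ONNX file back through the same `YOLO` class and run one prediction. The cell below is a minimal sketch. It assumes the export wrote `best.onnx` next to `best.pt` (the Ultralytics default) and that `test_image` from Step 6 is still defined. The names `onnx_model` and `onnx_results` are illustrative."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from ultralytics import YOLO\n",
"\n",
"# Sketch: the path below is an assumption based on where best.pt was saved\n",
"onnx_model = YOLO(\"pose_training/yolo26n_pose/weights/best.onnx\", task=\"pose\")\n",
"\n",
"onnx_results = onnx_model.predict(source=test_image, imgsz=640, conf=0.35)\n",
"\n",
"kpts = onnx_results[0].keypoints\n",
"count = len(kpts) if kpts is not None else 0\n",
"print(f\"ONNX model detected {count} instance(s)\")"
]
},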
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## What's Next?\n",
"\n",
"- 📖 [YOLO26 Docs](https://docs.ultralytics.com/models/yolo26)\n",
"- 🔍 [Roboflow Universe — Pose Datasets](https://universe.roboflow.com/search?q=pose+estimation)\n",
"- 💜 [Supervision Docs](https://supervision.roboflow.com)\n",
"- 🐛 [Report Issues](https://github.com/roboflow/notebooks/issues)\n",
"\n",
"<div align=\"center\">\n",
" <div>\n",
" <a href=\"https://youtube.com/roboflow\">\n",
" <img\n",
" src=\"https://media.roboflow.com/notebooks/template/icons/purple/youtube.png?ik-sdk-version=javascript-1.4.3&updatedAt=1672949748453\"\n",
" height=\"110\"\n",
" >\n",
" </a>\n",
" <a href=\"https://roboflow.com\">\n",
" <img\n",
" src=\"https://media.roboflow.com/notebooks/template/icons/purple/roboflow-app.png?ik-sdk-version=javascript-1.4.3&updatedAt=1672949748453\"\n",
" height=\"110\"\n",
" >\n",
" </a>\n",
" <a href=\"https://www.linkedin.com/company/roboflow-ai/\">\n",
" <img\n",
" src=\"https://media.roboflow.com/notebooks/template/icons/purple/linkedin.png?ik-sdk-version=javascript-1.4.3&updatedAt=1672949748453\"\n",
" height=\"110\"\n",
" >\n",
" </a>\n",
" <a href=\"https://roboflow.com/twitter\">\n",
" <img\n",
" src=\"https://media.roboflow.com/notebooks/template/icons/purple/twitter.png?ik-sdk-version=javascript-1.4.3&updatedAt=1672949748453\"\n",
" height=\"110\"\n",
" >\n",
" </a>\n",
" <a href=\"https://discord.gg/roboflow\">\n",
" <img\n",
" src=\"https://media.roboflow.com/notebooks/template/icons/purple/discord.png?ik-sdk-version=javascript-1.4.3&updatedAt=1672949748453\"\n",
" height=\"110\"\n",
" >\n",
" </a>\n",
" </div>\n",
"</div>"
]
}
]
}