Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions conf/finetune/gspo.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
defaults:
- base
- _self_

attempts: 8
rl:
policy_loss: gspo
epsilon_high: 4e-4
epsilon_low: 3e-4
47 changes: 47 additions & 0 deletions conf/swe.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
defaults:
- base
- _self_
- override finetune: gspo

model_path: Qwen/Qwen3-8B

actor:
rollout_policy: pipelinerl.domains.swe.rollouts.generate_swe_rollout
success_threshold: 0.8

environments: null

dataset_loader: pipelinerl.domains.swe.load_datasets.load_local_swe_dataset
dataset_loader_params:
seed: ${seed}
# max_samples: 1000 # uncomment to cap the number of loaded samples (applies to both train and test)

# HuggingFace Hub dataset IDs (or local disk paths).
# Append ":split" to restrict to a specific split, e.g. SWE-bench/SWE-smith-py:train
train_dataset_names:
- SWE-bench/SWE-smith-py
test_dataset_names:
- SWE-bench/SWE-smith-py
Comment thread
ehsk marked this conversation as resolved.
Outdated

finetune:
seq_length: 24000
rl:
filter_zero_advantage_groups: false

vllm_config:
vllm_kwargs:
max_model_len: 24000

llm:
parameters:
max_tokens: 4096
temperature: 1.0
chat_template_kwargs:
enable_thinking: false

test_llm:
parameters:
max_tokens: 4096
temperature: 0.0
chat_template_kwargs:
enable_thinking: false
15 changes: 15 additions & 0 deletions conf/swe/preprocess.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Config for swe_preprocessor.py.
# Clones repos, extracts gold_file_contents at base_commit, applies token
# filtering, and saves a training-ready HuggingFace disk dataset.
#
# Run: python -m pipelinerl.domains.swe.swe_preprocessor --config-name=swe/preprocess

hf_dataset_name: SWE-bench/SWE-smith-py
hf_split_name: train
repo_path: /path/to/repos
dataset_path: /path/to/output_ds
tokenizer_model: Qwen/Qwen3-8B
min_token_threshold: null # set to an int to filter out very short examples
max_token_threshold: 16000 # set to null to disable
num_map_processes: 32
force_reprocess: false
127 changes: 127 additions & 0 deletions pipelinerl/domains/swe/load_datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# Supported datasets
# ──────────────────────────────────────────────────────────────────────────────
# Ready to use (have gold_file_contents pre-extracted):
# SWE-bench/SWE-smith-py local preprocessed disk dataset or Hub ID
# SWE-bench/SWE-smith-java "
# SWE-bench/SWE-smith-rs "
# SWE-bench/SWE-smith-go "
#
# Require preprocessing first (clone repos, extract file contents at base_commit):
# princeton-nlp/SWE-bench
# princeton-nlp/SWE-bench_Lite
# princeton-nlp/SWE-bench_Verified
# SWE-bench/SWE-Pro (if/when released publicly)
#
# Run: python -m pipelinerl.domains.swe.swe_preprocessor --config-name=swe/preprocess
# ──────────────────────────────────────────────────────────────────────────────

import json
import logging
import os
import random
from typing import Any, Dict, List, Optional

from datasets import load_dataset, load_from_disk

logger = logging.getLogger(__name__)


def _parse_file_contents(raw: Any) -> Dict[str, str]:
if isinstance(raw, dict):
return {str(k): str(v) for k, v in raw.items()}
if isinstance(raw, str):
try:
parsed = json.loads(raw)
except (json.JSONDecodeError, TypeError):
return {}
if isinstance(parsed, dict):
return {str(k): str(v) for k, v in parsed.items()}
return {}


def _load_single_dataset(path: str) -> List[Dict]:
"""Load a dataset from a local disk path or a HuggingFace Hub ID.

Local path: /path/to/ds_train
Hub ID: SWE-bench/SWE-smith-py (all splits concatenated)
Hub ID+split: SWE-bench/SWE-smith-py:train
"""
if os.path.exists(path):
logger.info("Loading from disk: %s", path)
dataset = load_from_disk(path)
else:
# Hub ID, optionally with ":split" suffix
if ":" in path:
hub_id, split = path.rsplit(":", 1)
else:
hub_id, split = path, None

logger.info("Loading from HuggingFace Hub: %s (split=%s)", hub_id, split or "all")
loaded = load_dataset(hub_id, split=split)

if split is None:
# DatasetDict — concatenate all splits
from datasets import concatenate_datasets
dataset = concatenate_datasets(list(loaded.values()))
else:
dataset = loaded

logger.info("Loaded %d rows from %s", len(dataset), path)

samples = []
for row in dataset:
item = dict(row)
try:
file_contents = _parse_file_contents(item.get("gold_file_contents", "{}"))
if not file_contents:
continue
samples.append({
"id": item.get("id", "") or item.get("instance_id", "") or item.get("issue_id", ""),
"dataset": item.get("dataset", "") or path,
"repo": item.get("repo", ""),
"base_commit": item.get("base_commit", ""),
"problem_statement": item.get("problem_statement", ""),
"patch": item.get("patch", ""),
"file_contents": file_contents,
})
except Exception as e:
logger.warning("Skipping malformed item: %s", e)

return samples


def load_local_swe_dataset(
dataset_paths: List[str],
seed: int = 42,
max_samples: Optional[int] = None,
) -> List[Dict]:
"""
Load one or more SWE-style datasets from disk and return a combined, shuffled list.

Args:
dataset_paths: Passed via cfg.train_dataset_names / cfg.test_dataset_names.
Each entry is a filesystem path to a HuggingFace disk dataset.
Add multiple paths to mix datasets (e.g. swe-smith + swe-bench).
seed: Random seed for shuffling (inherit from cfg.seed via dataset_loader_params).
max_samples: Optional cap on the total number of returned samples.
"""
if not dataset_paths:
logger.error("No dataset paths provided")
return []

all_samples: List[Dict] = []
for path in dataset_paths:
try:
all_samples.extend(_load_single_dataset(path))
except Exception as e:
logger.error("Failed to load dataset from %s: %s", path, e, exc_info=True)

random.Random(seed).shuffle(all_samples)
logger.info("Shuffled %d samples (seed=%d)", len(all_samples), seed)

if max_samples and len(all_samples) > max_samples:
all_samples = all_samples[:max_samples]
logger.info("Trimmed to max_samples=%d", max_samples)

logger.info("Returning %d samples total", len(all_samples))
return all_samples
116 changes: 116 additions & 0 deletions pipelinerl/domains/swe/repair.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import logging
from typing import Dict, List

logger = logging.getLogger(__name__)

SYSTEM_PROMPT = "You are a helpful coding assistant that analyzes code and fixes bugs."

USER_PROMPT_TEMPLATE = (
"Analyze the following code to find and fix bugs. Use this format:\n\n"
"<think>\n"
"[Your analysis process - be as detailed as you want until you're confident in your solution]\n"
"</think>\n\n"
"<solution>\n"
"[Your SEARCH/REPLACE edits using this format:]\n\n"
"```\n"
"### filename.py\n"
"<<<<<<< SEARCH\n"
"[exact code to find]\n"
"=======\n"
"[replacement code]\n"
">>>>>>> REPLACE\n"
"```\n"
"</solution>\n\n"
"IMPORTANT REQUIREMENTS:\n"
"- Every SEARCH/REPLACE edit must use the exact format above\n"
"- The SEARCH block must contain a contiguous chunk of lines that exist in the source code\n"
"- PROPER INDENTATION IS CRITICAL - if you want to add ' print(x)', you must include all those spaces\n"
"- Wrap each SEARCH/REPLACE edit in a code block\n"
"- Use separate code blocks for multiple edits\n\n"
"Example:\n"
"```python\n"
"### mathweb/flask/app.py\n"
"<<<<<<< SEARCH\n"
"from flask import Flask\n"
"=======\n"
"import math\n"
"from flask import Flask\n"
">>>>>>> REPLACE\n"
"```\n\n"
"Here is the issue:\n"
"--- BEGIN ISSUE ---\n"
"{problem_statement}\n"
"--- END ISSUE ---\n\n"
"Below are the code files that may contain bugs:\n"
"{file_contents}"
)


def build_messages(problem_statement: str, file_contents: Dict[str, str]) -> List[dict]:
"""Build the chat messages for a single-turn repair prompt."""
formatted_files = "".join(
f"### {path}\n```\n{content}\n```\n\n"
for path, content in file_contents.items()
)
user_content = USER_PROMPT_TEMPLATE.format(
problem_statement=problem_statement,
file_contents=formatted_files,
)
return [
{"role": "system", "content": SYSTEM_PROMPT},
{"role": "user", "content": user_content},
]


def parse_edits(completion: str) -> List[dict]:
"""
Parse SEARCH/REPLACE blocks from a model completion.

Each block is a '### filepath' line followed by a
<<<<<<< SEARCH / ======= / >>>>>>> REPLACE triple. Triple-backtick code
fences around the block are accepted but not required.
Returns a list of {'file_path', 'search', 'replace'} dicts.
"""
edits: List[dict] = []
lines = completion.split('\n')
n = len(lines)
i = 0
while i < n:
if '<<<<<<< SEARCH' not in lines[i]:
i += 1
continue

# Walk back to the most recent '### filepath' line, but don't cross a
# previous '>>>>>>> REPLACE' marker (that path belongs to an earlier edit).
file_path = None
for j in range(i - 1, -1, -1):
if '>>>>>>> REPLACE' in lines[j]:
break
stripped = lines[j].strip()
if stripped.startswith('###'):
file_path = stripped[3:].strip()
break
if not file_path:
i += 1
continue

search_start = i + 1
sep = replace_end = None
for k in range(search_start, n):
if sep is None and '=======' in lines[k]:
sep = k
elif sep is not None and '>>>>>>> REPLACE' in lines[k]:
replace_end = k
break

if sep is None or replace_end is None:
i += 1
continue

edits.append({
'file_path': file_path,
'search': '\n'.join(lines[search_start:sep]),
'replace': '\n'.join(lines[sep + 1:replace_end]),
})
i = replace_end + 1
return edits
Loading