Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions ggml/include/ggml-backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,11 @@ extern "C" {

GGML_API void ggml_backend_synchronize(ggml_backend_t backend);

GGML_API bool ggml_backend_supports_graph_plan(ggml_backend_t backend);
GGML_API bool ggml_backend_supports_graph_plan_update(ggml_backend_t backend);
GGML_API ggml_backend_graph_plan_t ggml_backend_graph_plan_create(ggml_backend_t backend, struct ggml_cgraph * cgraph);
GGML_API void ggml_backend_graph_plan_free (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
GGML_API void ggml_backend_graph_plan_update(ggml_backend_t backend, ggml_backend_graph_plan_t plan, const struct ggml_cgraph * cgraph);

GGML_API enum ggml_status ggml_backend_graph_plan_compute (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
GGML_API enum ggml_status ggml_backend_graph_compute (ggml_backend_t backend, struct ggml_cgraph * cgraph);
Expand Down
112 changes: 111 additions & 1 deletion ggml/src/ggml-backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,18 @@ void ggml_backend_synchronize(ggml_backend_t backend) {
backend->iface.synchronize(backend);
}

bool ggml_backend_supports_graph_plan(ggml_backend_t backend) {
GGML_ASSERT(backend);

return (bool) backend->iface.graph_plan_create;
}

bool ggml_backend_supports_graph_plan_update(ggml_backend_t backend) {
GGML_ASSERT(backend);

return (bool) backend->iface.graph_plan_update;
}

ggml_backend_graph_plan_t ggml_backend_graph_plan_create(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
GGML_ASSERT(backend);
GGML_ASSERT(backend->iface.graph_plan_create != NULL);
Expand All @@ -434,6 +446,13 @@ void ggml_backend_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_pla
backend->iface.graph_plan_free(backend, plan);
}

void ggml_backend_graph_plan_update(ggml_backend_t backend, ggml_backend_graph_plan_t plan, const struct ggml_cgraph* cgraph) {
GGML_ASSERT(backend);
GGML_ASSERT(backend->iface.graph_plan_update != NULL);

backend->iface.graph_plan_update(backend, plan, cgraph);
}

enum ggml_status ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
GGML_ASSERT(backend);
GGML_ASSERT(backend->iface.graph_plan_compute != NULL);
Expand Down Expand Up @@ -771,6 +790,11 @@ struct ggml_backend_sched_split {
struct ggml_cgraph graph;
};

struct ggml_backend_sched_plan {
int backend_id;
ggml_backend_graph_plan_t plan;
};

struct ggml_backend_sched {
bool is_reset; // true if the scheduler has been reset since the last graph split
bool is_alloc;
Expand Down Expand Up @@ -800,6 +824,12 @@ struct ggml_backend_sched {
int n_splits;
int splits_capacity;

// graph plans
struct ggml_backend_sched_plan * plans;
int n_plans;
int plans_capacity;
bool plan_needs_update;

// pipeline parallelism support
int n_copies;
int cur_copy;
Expand Down Expand Up @@ -1010,6 +1040,16 @@ static void ggml_backend_sched_set_if_supported(ggml_backend_sched_t sched, stru
}
}

static void ggml_backend_sched_free_plans(ggml_backend_sched_t sched) {
for (int i = 0; i < sched->n_plans; i++) {
ggml_backend_t backend = sched->backends[sched->plans[i].backend_id];
if (ggml_backend_supports_graph_plan(backend)) {
ggml_backend_graph_plan_free(backend, sched->plans[i].plan);
}
}
sched->n_plans = 0;
}

// assigns backends to ops and splits the graph into subgraphs that can be computed on the same backend
void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
// reset splits
Expand Down Expand Up @@ -1484,6 +1524,7 @@ void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgra
for (int i = 0; i < sched->n_splits; ++i) {
sched->splits[i].graph.uid = ggml_graph_next_uid();
}
sched->plan_needs_update = true;
}

static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) {
Expand Down Expand Up @@ -1538,6 +1579,62 @@ static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) {
return true;
}

static void ggml_backend_sched_update_plans(ggml_backend_sched_t sched) {
// create graph plans
if (sched->plan_needs_update) {
bool create_new_plans;
if (sched->n_plans == sched->n_splits) {
create_new_plans = false;
for (int i = 0; i < sched->n_splits; i++) {
if (sched->splits[i].backend_id != sched->plans[i].backend_id) {
create_new_plans = true;
break;
}
}
} else {
create_new_plans = true;
}
if (create_new_plans) {
// free previous and recreate new plans
ggml_backend_sched_free_plans(sched);
if (sched->plans_capacity < sched->n_splits) {
while (sched->plans_capacity < sched->n_splits) {
sched->plans_capacity *= 2;
}
sched->plans = (ggml_backend_sched_plan *) realloc(
sched->plans, sched->plans_capacity * sizeof(struct ggml_backend_sched_plan));
GGML_ASSERT(sched->plans);
}
sched->n_plans = sched->n_splits;
for (int i = 0; i < sched->n_splits; i++) {
ggml_backend_t backend = sched->backends[sched->splits[i].backend_id];
sched->plans[i].backend_id = sched->splits[i].backend_id;
if (ggml_backend_supports_graph_plan(backend)) {
sched->plans[i].plan = ggml_backend_graph_plan_create(backend, &sched->splits[i].graph);
} else {
sched->plans[i].plan = nullptr;
}
}
} else {
// update existing plans
for (int i = 0; i < sched->n_splits; i++) {
ggml_backend_t backend = sched->backends[sched->splits[i].backend_id];
if (ggml_backend_supports_graph_plan(backend)) {
if (ggml_backend_supports_graph_plan_update(backend)) {
ggml_backend_graph_plan_update(backend, sched->plans[i].plan, &sched->splits[i].graph);
} else {
ggml_backend_graph_plan_free(backend, sched->plans[i].plan);
sched->plans[i].plan = ggml_backend_graph_plan_create(backend, &sched->splits[i].graph);
}
} else {
sched->plans[i].plan = nullptr;
}
}
}
sched->plan_needs_update = false;
}
}

static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t sched) {
GGML_ASSERT(sched);
struct ggml_backend_sched_split * splits = sched->splits;
Expand All @@ -1546,6 +1643,8 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
std::vector<int32_t> ids;
std::vector<ggml_bitset_t> used_ids;

ggml_backend_sched_update_plans(sched);

for (int split_id = 0; split_id < sched->n_splits; split_id++) {
struct ggml_backend_sched_split * split = &splits[split_id];
int split_backend_id = split->backend_id;
Expand Down Expand Up @@ -1675,7 +1774,12 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s
}

if (!sched->callback_eval) {
enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &split->graph);
enum ggml_status ec;
if (ggml_backend_supports_graph_plan(split_backend) && sched->plans[split_id].plan) {
ec = ggml_backend_graph_plan_compute(split_backend, sched->plans[split_id].plan);
} else {
ec = ggml_backend_graph_compute_async(split_backend, &split->graph);
}
if (ec != GGML_STATUS_SUCCESS) {
return ec;
}
Expand Down Expand Up @@ -1773,6 +1877,10 @@ ggml_backend_sched_t ggml_backend_sched_new(
sched->splits = (ggml_backend_sched_split *) calloc(initial_splits_capacity, sizeof(sched->splits[0]));
sched->splits_capacity = initial_splits_capacity;

const int initial_plans_capacity = 16;
sched->plans = (ggml_backend_sched_plan *) calloc(initial_plans_capacity, sizeof(sched->plans[0]));
sched->plans_capacity = initial_plans_capacity;

for (int b = 0; b < n_backends; b++) {
sched->backends[b] = backends[b];
sched->bufts[b] = bufts ? bufts[b] : ggml_backend_get_default_buffer_type(backends[b]);
Expand Down Expand Up @@ -1806,6 +1914,8 @@ void ggml_backend_sched_free(ggml_backend_sched_t sched) {
ggml_free(sched->ctx);
ggml_hash_set_free(&sched->hash_set);
free(sched->splits);
ggml_backend_sched_free_plans(sched);
free(sched->plans);
free(sched->hv_tensor_backend_ids);
free(sched->hv_tensor_copies);
free(sched->node_backend_ids);
Expand Down
Loading