Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions src/binding/python/model/param/python_param.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ std::pair<std::string, std::string> serialize_sparse_vector(
const size_t n = sparse_dict.size();
if (n == 0) return {{}, {}};

const auto sorted_items = py::module_::import("builtins")
.attr("sorted")(sparse_dict.attr("items")());

std::string indices_buf;
indices_buf.resize(n * sizeof(uint32_t));
auto *indices_ptr = reinterpret_cast<uint32_t *>(indices_buf.data());
Expand All @@ -106,9 +109,10 @@ std::pair<std::string, std::string> serialize_sparse_vector(
auto *values_ptr = reinterpret_cast<ValueType *>(values_buf.data());

size_t i = 0;
for (const auto &[py_key, py_val] : sparse_dict) {
indices_ptr[i] = checked_cast<uint32_t>(py_key, "Sparse indices", "UINT32");
values_ptr[i] = value_caster(py_val, i);
for (const auto &item : sorted_items) {
auto tup = item.cast<py::tuple>();
indices_ptr[i] = checked_cast<uint32_t>(tup[0], "Sparse indices", "UINT32");
values_ptr[i] = value_caster(tup[1], i);
++i;
}
return {std::move(indices_buf), std::move(values_buf)};
Expand Down
36 changes: 20 additions & 16 deletions src/binding/python/model/python_doc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -165,20 +165,22 @@ void ZVecPyDoc::bind_doc(py::module_ &m) {
case DataType::SPARSE_VECTOR_FP32: {
const auto sparse_dict =
checked_cast<py::dict>(obj, field, "SPARSE_VECTOR_FP32 (dict)");
const auto sorted_items =
py::module_::import("builtins")
.attr("sorted")(sparse_dict.attr("items")());
std::vector<uint32_t> indices;
std::vector<float> values;
for (const auto &item : sparse_dict) {
indices.reserve(sparse_dict.size());
values.reserve(sparse_dict.size());
for (const auto &item : sorted_items) {
try {
indices.push_back(item.first.cast<uint32_t>());
values.push_back(item.second.cast<float>());
auto tup = item.cast<py::tuple>();
indices.push_back(tup[0].cast<uint32_t>());
values.push_back(tup[1].cast<float>());
} catch (const py::cast_error &e) {
throw py::type_error(
"Vector '" + field +
"': sparse vector key/value must be (uint32, float), "
"got key=" +
std::string(py::str(py::type::of(item.first))) +
", value=" +
std::string(py::str(py::type::of(item.second))));
"': sparse vector key/value must be (uint32, float)");
}
}
const std::pair<std::vector<uint32_t>, std::vector<float>>
Expand All @@ -188,20 +190,22 @@ void ZVecPyDoc::bind_doc(py::module_ &m) {
case DataType::SPARSE_VECTOR_FP16: {
const auto sparse_dict =
checked_cast<py::dict>(obj, field, "SPARSE_VECTOR_FP16 (dict)");
const auto sorted_items =
py::module_::import("builtins")
.attr("sorted")(sparse_dict.attr("items")());
std::vector<uint32_t> indices;
std::vector<ailego::Float16> values;
for (const auto &item : sparse_dict) {
indices.reserve(sparse_dict.size());
values.reserve(sparse_dict.size());
for (const auto &item : sorted_items) {
try {
indices.push_back(item.first.cast<uint32_t>());
values.push_back(ailego::Float16(item.second.cast<float>()));
auto tup = item.cast<py::tuple>();
indices.push_back(tup[0].cast<uint32_t>());
values.push_back(ailego::Float16(tup[1].cast<float>()));
} catch (const py::cast_error &e) {
throw py::type_error(
"Field '" + field +
"': sparse vector key/value must be (uint32, float), "
"got key=" +
std::string(py::str(py::type::of(item.first))) +
", value=" +
std::string(py::str(py::type::of(item.second))));
"': sparse vector key/value must be (uint32, float)");
}
}
const std::pair<std::vector<uint32_t>, std::vector<ailego::Float16>>
Expand Down
2 changes: 0 additions & 2 deletions src/db/collection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
#include "db/common/file_helper.h"
#include "db/common/profiler.h"
#include "db/common/typedef.h"
#include "db/index/column/vector_column/vector_column_indexer.h"
#include "db/index/common/delete_store.h"
#include "db/index/common/id_map.h"
#include "db/index/common/index_filter.h"
Expand Down Expand Up @@ -1458,7 +1457,6 @@ Result<WriteResults> CollectionImpl::write_impl(std::vector<Doc> &docs,
kMaxWriteBatchSize));
}

// validate docs
for (auto &&doc : docs) {
if (need_switch_to_new_segment()) {
auto s = switch_to_new_segment_for_writing();
Expand Down
119 changes: 99 additions & 20 deletions src/db/index/common/doc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,9 @@ std::string get_value_type_name(const Doc::Value &value, bool is_vector) {
value);
}


namespace {

template <typename T>
T byte_swap(T value) {
if constexpr (std::is_same_v<T, float16_t>) {
Expand Down Expand Up @@ -159,6 +162,42 @@ T read_value_from_buffer(const uint8_t *&data) {
return value;
}

template <typename T>
std::string vec_to_string(const std::vector<T> &v) {
std::ostringstream oss;
oss << "[";
for (size_t i = 0; i < v.size(); ++i) {
if (i > 0) oss << ", ";
oss << +v[i]; // + from print as char
}
oss << "]";
return oss.str();
}

template <class... Ts>
struct overloaded : Ts... {
using Ts::operator()...;
};

template <class... Ts>
overloaded(Ts...) -> overloaded<Ts...>;

enum class SparseIndexCheckResult { kOk, kUnsorted, kDuplicate };

SparseIndexCheckResult check_sparse_indices(const uint32_t *indices, size_t n) {
for (size_t i = 1; i < n; ++i) {
if (indices[i] == indices[i - 1]) {
return SparseIndexCheckResult::kDuplicate;
}
if (indices[i] < indices[i - 1]) {
return SparseIndexCheckResult::kUnsorted;
}
}
return SparseIndexCheckResult::kOk;
}

} // namespace


void Doc::write_to_buffer(std::vector<uint8_t> &buffer, const void *src,
size_t size) {
Expand Down Expand Up @@ -874,6 +913,18 @@ Status Doc::validate(const CollectionSchema::Ptr &schema,
"] exceeds the maximum number of sparse indices (",
kSparseMaxDimSize, ")");
}
auto check = check_sparse_indices(sparse_indices.data(),
sparse_indices.size());
if (check == SparseIndexCheckResult::kUnsorted) {
return Status::InvalidArgument(
"Invalid doc[", pk_, "]: sparse vector field[", field_name,
"] indices are not sorted in ascending order");
}
if (check == SparseIndexCheckResult::kDuplicate) {
return Status::InvalidArgument(
"Invalid doc[", pk_, "]: sparse vector field[", field_name,
"] contains duplicate indices");
}
}
break;
}
Expand All @@ -895,6 +946,18 @@ Status Doc::validate(const CollectionSchema::Ptr &schema,
"] exceeds the maximum number of sparse indices (",
kSparseMaxDimSize, ")");
}
auto check = check_sparse_indices(sparse_indices.data(),
sparse_indices.size());
if (check == SparseIndexCheckResult::kUnsorted) {
return Status::InvalidArgument(
"Invalid doc[", pk_, "]: sparse vector field[", field_name,
"] indices are not sorted in ascending order");
}
if (check == SparseIndexCheckResult::kDuplicate) {
return Status::InvalidArgument(
"Invalid doc[", pk_, "]: sparse vector field[", field_name,
"] contains duplicate indices");
}
}
break;
}
Expand Down Expand Up @@ -1036,24 +1099,6 @@ size_t Doc::memory_usage() const {
return usage;
}

template <typename T>
std::string vec_to_string(const std::vector<T> &v) {
std::ostringstream oss;
oss << "[";
for (size_t i = 0; i < v.size(); ++i) {
if (i > 0) oss << ", ";
oss << +v[i]; // + from print as char
}
oss << "]";
return oss.str();
}

template <class... Ts>
struct overloaded : Ts... {
using Ts::operator()...;
};
template <class... Ts>
overloaded(Ts...) -> overloaded<Ts...>;

std::string Doc::to_detail_string() const {
std::stringstream oss;
Expand Down Expand Up @@ -1274,12 +1319,46 @@ Status VectorQuery::validate(const FieldSchema *schema) const {
"] is not a dense vector field");
}
} else if (schema->is_sparse_vector()) {
// Validate sparse indices size
if (query_sparse_indices_.size() > kSparseMaxDimSize * sizeof(uint32_t)) {
size_t value_byte_size = 0;
switch (schema->data_type()) {
case DataType::SPARSE_VECTOR_FP32:
value_byte_size = sizeof(float);
break;
case DataType::SPARSE_VECTOR_FP16:
value_byte_size = sizeof(float16_t);
break;
default:
return Status::InvalidArgument(
"Invalid query: sparse vector type of field[", field_name_,
"] is not supported");
}
if (query_sparse_indices_.size() % sizeof(uint32_t) != 0 ||
query_sparse_values_.size() % value_byte_size != 0 ||
query_sparse_indices_.size() / sizeof(uint32_t) !=
query_sparse_values_.size() / value_byte_size) {
return Status::InvalidArgument(
"Invalid query: sparse vector query for field[", field_name_,
"] has mismatched indices and values sizes");
}
size_t n_indices = query_sparse_indices_.size() / sizeof(uint32_t);
if (n_indices > kSparseMaxDimSize) {
return Status::InvalidArgument(
"Invalid query: too many sparse indices, the maximum allowed is ",
kSparseMaxDimSize);
}
const auto *query_indices_ptr =
reinterpret_cast<const uint32_t *>(query_sparse_indices_.data());
auto check = check_sparse_indices(query_indices_ptr, n_indices);
if (check == SparseIndexCheckResult::kUnsorted) {
return Status::InvalidArgument(
"Invalid query: sparse vector query for field[", field_name_,
"] indices are not sorted in ascending order");
}
if (check == SparseIndexCheckResult::kDuplicate) {
return Status::InvalidArgument(
"Invalid query: sparse vector query for field[", field_name_,
"] contains duplicate indices");
}
} else {
return Status::InvalidArgument("Invalid query: field[", field_name_,
"] is not a vector field");
Expand Down
Loading
Loading