-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathktensor.h
More file actions
88 lines (76 loc) · 2.82 KB
/
Copy pathktensor.h
File metadata and controls
88 lines (76 loc) · 2.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#pragma once
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <initializer_list>
#include <stdexcept>
#include <vector>
/// Minimal owning, dense float tensor with at most four dimensions.
///
/// Elements live contiguously in `data`, laid out so that the last dimension
/// varies fastest (see `at2` / `at3`). The element accessors and `size()` are
/// deliberately unchecked; callers are responsible for passing valid indices.
struct KTensor {
    /// Number of slots available in `shape`.
    static constexpr int kMaxDims = 4;

    std::vector<float> data;                    ///< Flat element storage.
    int64_t shape[kMaxDims] = {0, 0, 0, 0};     ///< Extent per dimension; unused slots stay 0.
    int ndim = 0;                               ///< Number of leading `shape` slots in use.

    KTensor() = default;

    /// Builds a zero-filled tensor with the given extents.
    /// An empty `dims` list yields a rank-0 tensor holding a single element.
    /// @throws std::invalid_argument if more than four extents are given or
    ///         any extent is negative.
    static KTensor zeros(std::initializer_list<int64_t> dims) {
        KTensor t;
        const int64_t count = assign_shape(t, dims);
        t.data.assign(static_cast<std::size_t>(count), 0.0f);
        return t;
    }

    /// Builds a tensor with the given extents by copying the product of
    /// `dims` floats from `src`.
    /// @throws std::invalid_argument if more than four extents are given, any
    ///         extent is negative, or `src` is null while elements are needed.
    static KTensor from_data(const float* src, std::initializer_list<int64_t> dims) {
        KTensor t;
        const int64_t count = assign_shape(t, dims);
        if (count > 0 && src == nullptr) {
            throw std::invalid_argument("KTensor::from_data: null source pointer");
        }
        t.data.assign(src, src + count);
        return t;
    }

    /// True when the tensor owns at least one element. A tensor with a zero
    /// extent therefore reports false, just like a default-constructed one.
    bool defined() const { return !data.empty(); }

    /// Raw pointer to the first element of the flat storage.
    float* ptr() { return data.data(); }
    const float* ptr() const { return data.data(); }

    /// Extent of dimension `dim`. Unchecked: `dim` must index into `shape`.
    int64_t size(int dim) const { return shape[dim]; }

    /// Total number of stored elements.
    int64_t numel() const { return static_cast<int64_t>(data.size()); }

    /// Flat element access (unchecked).
    float& at(int64_t i) { return data[i]; }
    const float& at(int64_t i) const { return data[i]; }

    /// 2-D element access using `shape[1]` as the row length (unchecked).
    float& at2(int64_t r, int64_t c) { return data[r * shape[1] + c]; }
    const float& at2(int64_t r, int64_t c) const { return data[r * shape[1] + c]; }

    /// 3-D element access using `shape[1]` and `shape[2]` as strides (unchecked).
    float& at3(int64_t d0, int64_t d1, int64_t d2) {
        return data[(d0 * shape[1] + d1) * shape[2] + d2];
    }
    const float& at3(int64_t d0, int64_t d1, int64_t d2) const {
        return data[(d0 * shape[1] + d1) * shape[2] + d2];
    }

    /// Copies leading floats from `src`'s storage to the start of this
    /// tensor's storage. The number copied is the smallest of `count`, both
    /// tensors' `shape[1]`, and both tensors' element counts; nothing happens
    /// when that number is not positive.
    void copy_slice1(const KTensor& src, int64_t count) {
        const int64_t shape_limit = std::min(shape[1], src.shape[1]);
        // shape[1] can be positive while the storage is empty (e.g. {0, N}),
        // so the real element counts must bound the copy as well.
        const int64_t storage_limit = std::min(numel(), src.numel());
        const int64_t n = std::min(count, std::min(shape_limit, storage_limit));
        if (n <= 0) return;  // a negative length would wrap to a huge size_t
        // memmove rather than memcpy: `src` may be this very tensor.
        std::memmove(data.data(), src.data.data(),
                     static_cast<std::size_t>(n) * sizeof(float));
    }

    /// For each d in [0, dim0), stores `src[d * src_stride]` at flat index
    /// `d * shape[ndim - 1] + dst_col`, i.e. fills position `dst_col` of
    /// `dim0` consecutive rows. Does nothing on a rank-0 tensor, which has no
    /// last dimension to stride by. Indices are otherwise unchecked.
    void copy_row(int64_t dst_col, const float* src, int64_t dim0, int64_t src_stride) {
        if (ndim < 1) return;
        const int64_t row_len = shape[ndim - 1];
        for (int64_t d = 0; d < dim0; ++d) {
            data[d * row_len + dst_col] = src[d * src_stride];
        }
    }

    /// Writes `col_data[0 .. dim1)` into elements `(0, d, dst_col)` of a
    /// 3-D tensor. Indices are unchecked.
    void copy_col_3d(int64_t dst_col, const float* col_data, int64_t dim1) {
        for (int64_t d = 0; d < dim1; ++d) {
            at3(0, d, dst_col) = col_data[d];
        }
    }

    /// Shrinks the last dimension to `new_size`, keeping the first `new_size`
    /// entries of every row and repacking the storage contiguously.
    /// No-op for rank-0 tensors or when `new_size` is not smaller than the
    /// current extent.
    /// @throws std::invalid_argument if `new_size` is negative.
    void truncate_last_dim(int64_t new_size) {
        if (ndim < 1) return;
        const int64_t old_size = shape[ndim - 1];
        if (new_size >= old_size) return;
        if (new_size < 0) {
            throw std::invalid_argument("KTensor::truncate_last_dim: negative size");
        }

        int64_t outer = 1;
        for (int i = 0; i + 1 < ndim; ++i) outer *= shape[i];

        // Compact in place. Row 0 is already where it belongs; every later
        // row moves towards the front, so a forward copy never overwrites
        // source elements that are still to be read.
        float* const base = data.data();
        for (int64_t o = 1; o < outer; ++o) {
            const float* const row = base + o * old_size;
            std::copy(row, row + new_size, base + o * new_size);
        }
        data.resize(static_cast<std::size_t>(outer * new_size));
        shape[ndim - 1] = new_size;
    }

private:
    /// Validates `dims`, records them in `t.shape` / `t.ndim`, and returns
    /// their product (1 for an empty list). The product itself is not
    /// checked for int64 overflow.
    static int64_t assign_shape(KTensor& t, std::initializer_list<int64_t> dims) {
        if (dims.size() > static_cast<std::size_t>(kMaxDims)) {
            throw std::invalid_argument("KTensor: at most 4 dimensions are supported");
        }
        int64_t count = 1;
        int axis = 0;
        for (const int64_t extent : dims) {
            if (extent < 0) {
                throw std::invalid_argument("KTensor: negative dimension");
            }
            t.shape[axis++] = extent;
            count *= extent;
        }
        t.ndim = axis;
        return count;
    }
};