Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 16 additions & 34 deletions tensorflow/lite/micro/kernels/cmsis_nn/maximum_minimum.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,10 @@ namespace tflite {

namespace {

cmsis_nn_dims FillVariableShape(int32_t rank, int32_t* tensor_dims) {
if (rank == 4) {
return {tensor_dims[0], tensor_dims[1], tensor_dims[2], tensor_dims[3]};
} else if (rank == 3) {
return {1, tensor_dims[0], tensor_dims[1], tensor_dims[2]};
} else if (rank == 2) {
return {1, 1, tensor_dims[0], tensor_dims[1]};
} else {
return {1, 1, 1, 1};
}
cmsis_nn_dims FillVariableShape(const RuntimeShape& shape) {
RuntimeShape extended_shape = RuntimeShape::ExtendedShape(4, shape);
return {extended_shape.Dims(0), extended_shape.Dims(1),
extended_shape.Dims(2), extended_shape.Dims(3)};
}

TfLiteStatus EvalMaximum(TfLiteContext* context, TfLiteNode* node) {
Expand All @@ -55,12 +49,9 @@ TfLiteStatus EvalMaximum(TfLiteContext* context, TfLiteNode* node) {
RuntimeShape input_2_shape = tflite::micro::GetTensorShape(input2);
RuntimeShape output_shape = tflite::micro::GetTensorShape(output);

cmsis_nn_dims input_1_dims = FillVariableShape(
input_1_shape.DimensionsCount(), input_1_shape.DimsData());
cmsis_nn_dims input_2_dims = FillVariableShape(
input_2_shape.DimensionsCount(), input_2_shape.DimsData());
cmsis_nn_dims output_dims = FillVariableShape(output_shape.DimensionsCount(),
output_shape.DimsData());
cmsis_nn_dims input_1_dims = FillVariableShape(input_1_shape);
cmsis_nn_dims input_2_dims = FillVariableShape(input_2_shape);
cmsis_nn_dims output_dims = FillVariableShape(output_shape);

switch (op_context.output->type) {
case kTfLiteInt8:
Expand Down Expand Up @@ -107,12 +98,9 @@ TfLiteStatus EvalMaximumInt8(TfLiteContext* context, TfLiteNode* node) {
RuntimeShape input_2_shape = tflite::micro::GetTensorShape(input2);
RuntimeShape output_shape = tflite::micro::GetTensorShape(output);

cmsis_nn_dims input_1_dims = FillVariableShape(
input_1_shape.DimensionsCount(), input_1_shape.DimsData());
cmsis_nn_dims input_2_dims = FillVariableShape(
input_2_shape.DimensionsCount(), input_2_shape.DimsData());
cmsis_nn_dims output_dims = FillVariableShape(output_shape.DimensionsCount(),
output_shape.DimsData());
cmsis_nn_dims input_1_dims = FillVariableShape(input_1_shape);
cmsis_nn_dims input_2_dims = FillVariableShape(input_2_shape);
cmsis_nn_dims output_dims = FillVariableShape(output_shape);

switch (op_context.output->type) {
case kTfLiteInt8:
Expand Down Expand Up @@ -147,12 +135,9 @@ TfLiteStatus EvalMinimum(TfLiteContext* context, TfLiteNode* node) {
RuntimeShape input_2_shape = tflite::micro::GetTensorShape(input2);
RuntimeShape output_shape = tflite::micro::GetTensorShape(output);

cmsis_nn_dims input_1_dims = FillVariableShape(
input_1_shape.DimensionsCount(), input_1_shape.DimsData());
cmsis_nn_dims input_2_dims = FillVariableShape(
input_2_shape.DimensionsCount(), input_2_shape.DimsData());
cmsis_nn_dims output_dims = FillVariableShape(output_shape.DimensionsCount(),
output_shape.DimsData());
cmsis_nn_dims input_1_dims = FillVariableShape(input_1_shape);
cmsis_nn_dims input_2_dims = FillVariableShape(input_2_shape);
cmsis_nn_dims output_dims = FillVariableShape(output_shape);

switch (op_context.output->type) {
case kTfLiteInt8:
Expand Down Expand Up @@ -199,12 +184,9 @@ TfLiteStatus EvalMinimumInt8(TfLiteContext* context, TfLiteNode* node) {
RuntimeShape input_2_shape = tflite::micro::GetTensorShape(input2);
RuntimeShape output_shape = tflite::micro::GetTensorShape(output);

cmsis_nn_dims input_1_dims = FillVariableShape(
input_1_shape.DimensionsCount(), input_1_shape.DimsData());
cmsis_nn_dims input_2_dims = FillVariableShape(
input_2_shape.DimensionsCount(), input_2_shape.DimsData());
cmsis_nn_dims output_dims = FillVariableShape(output_shape.DimensionsCount(),
output_shape.DimsData());
cmsis_nn_dims input_1_dims = FillVariableShape(input_1_shape);
cmsis_nn_dims input_2_dims = FillVariableShape(input_2_shape);
cmsis_nn_dims output_dims = FillVariableShape(output_shape);

switch (op_context.output->type) {
case kTfLiteInt8:
Expand Down
18 changes: 13 additions & 5 deletions tensorflow/lite/micro/kernels/cmsis_nn/pad.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,21 @@ TfLiteStatus PadEvalInt8(TfLiteContext* context, TfLiteNode* node) {
int8_t* output_ptr = tflite::micro::GetTensorData<int8_t>(output);

const RuntimeShape d = tflite::micro::GetTensorShape(input);
const cmsis_nn_dims input_size = {d.Dims(0), d.Dims(1), d.Dims(2), d.Dims(3)};
const int rank = d.DimensionsCount();

cmsis_nn_dims input_size = {
rank >= 4 ? d.Dims(rank - 4) : 1, rank >= 3 ? d.Dims(rank - 3) : 1,
rank >= 2 ? d.Dims(rank - 2) : 1, rank >= 1 ? d.Dims(rank - 1) : 1};

const PadParams p = data->params;
const cmsis_nn_dims pre_pad = {p.left_padding[0], p.left_padding[1],
p.left_padding[2], p.left_padding[3]};
const cmsis_nn_dims post_pad = {p.right_padding[0], p.right_padding[1],
p.right_padding[2], p.right_padding[3]};
cmsis_nn_dims pre_pad = {rank >= 4 ? p.left_padding[rank - 4] : 0,
rank >= 3 ? p.left_padding[rank - 3] : 0,
rank >= 2 ? p.left_padding[rank - 2] : 0,
rank >= 1 ? p.left_padding[rank - 1] : 0};
cmsis_nn_dims post_pad = {rank >= 4 ? p.right_padding[rank - 4] : 0,
rank >= 3 ? p.right_padding[rank - 3] : 0,
rank >= 2 ? p.right_padding[rank - 2] : 0,
rank >= 1 ? p.right_padding[rank - 1] : 0};

arm_pad_s8(input_ptr, output_ptr, pad_value, &input_size, &pre_pad,
&post_pad);
Expand Down
Loading