Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions services/ml/compilation_impl_nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,11 @@ CompilationImplNN::CompilationImplNN(const ModelImplNN* model,
}

CompilationImplNN::~CompilationImplNN() {
// ANeuralNetworksCompilation_free(nn_compilation_);
// The nn_compilation_ will be deleted in execution phase.
#if defined(OS_ANDROID)
ANeuralNetworksCompilation_free(nn_compilation_);
#else
IE(ie_compilation_free)(ie_compilation_);
#endif
}

void CompilationImplNN::Finish(int32_t preference, FinishCallback callback) {
Expand Down
24 changes: 8 additions & 16 deletions services/ml/execution_impl_nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,17 @@

namespace ml {

// TODO:: CompilationImplNN* => std::unique<CompilationImplNN> so that
// ie_compilation_free(ie_compilation_); can host in class CompilationImplNN.
ExecutionImplNN::ExecutionImplNN(const CompilationImplNN* compilation,
mojo::ScopedSharedBufferHandle memory)
: operands_(compilation->operands_),
operations_(compilation->operations_),
inputs_(compilation->inputs_),
outputs_(compilation->outputs_),
memory_(std::move(memory)),
#if defined(OS_ANDROID)
nn_compilation_(compilation->nn_compilation_) {
#else
ie_compilation_(compilation->ie_compilation_) {
#endif
compilation_impl_(compilation) {
#if defined(OS_LINUX) || defined(OS_WIN)
// Create Execution
IE(ie_execution_create)(ie_compilation_, &ie_execution_);
IE(ie_execution_create)(compilation_impl_->ie_compilation_, &ie_execution_);
#endif
uint32_t total_length = 0;
inputs_info_.reserve(inputs_.size());
Expand All @@ -54,9 +48,7 @@ ExecutionImplNN::ExecutionImplNN(const CompilationImplNN* compilation,

ExecutionImplNN::~ExecutionImplNN() {
#if defined(OS_ANDROID)
ANeuralNetworksCompilation_free(nn_compilation_);
#else
IE(ie_compilation_free)(ie_compilation_);
IE(ie_execution_free)(ie_execution_);
#endif
DLOG(INFO) << "ANeuralNetworksCompilation_free";
Expand Down Expand Up @@ -91,8 +83,8 @@ void ExecutionImplNN::StartCompute(mojom::UserBufferPtr user_buffer,
int32_t result = 0;
#if defined(OS_ANDROID)
ANeuralNetworksExecution* nn_execution;
result =
ANeuralNetworksExecution_create(nn_compilation_, &nn_execution);
result = ANeuralNetworksExecution_create(compilation_impl_->nn_compilation_,
&nn_execution);
#endif
for (size_t i = 0; i < inputs_info_.size(); ++i) {
std::unique_ptr<OperandInfo>& info = inputs_info_[i];
Expand All @@ -101,8 +93,8 @@ void ExecutionImplNN::StartCompute(mojom::UserBufferPtr user_buffer,
nn_execution, i, NULL, static_cast<void*>(info->mapping.get()),
info->length);
#else
result = IE(ie_execution_set_input)(ie_execution_, i,
info->mapping.get(), info->length);
result = IE(ie_execution_set_input)(ie_execution_, i, info->mapping.get(),
info->length);
#endif
}

Expand All @@ -113,8 +105,8 @@ void ExecutionImplNN::StartCompute(mojom::UserBufferPtr user_buffer,
nn_execution, i, NULL, static_cast<void*>(info->mapping.get()),
info->length);
#else
result = IE(ie_execution_set_output)(
ie_execution_, i, info->mapping.get(), info->length);
result = IE(ie_execution_set_output)(ie_execution_, i, info->mapping.get(),
info->length);
#endif
}

Expand Down
10 changes: 4 additions & 6 deletions services/ml/execution_impl_nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <vector>

#include "base/macros.h"
#include "base/memory/scoped_refptr.h"
#include "services/ml/common.h"
#include "services/ml/compilation_impl_nn.h"
#include "services/ml/model_impl_nn.h"
Expand All @@ -28,8 +29,7 @@ namespace ml {

class ExecutionImplNN : public mojom::Execution {
public:
ExecutionImplNN(const CompilationImplNN*,
mojo::ScopedSharedBufferHandle);
ExecutionImplNN(const CompilationImplNN*, mojo::ScopedSharedBufferHandle);
~ExecutionImplNN() override;

void StartCompute(mojom::UserBufferPtr user_buffer,
Expand All @@ -46,13 +46,11 @@ class ExecutionImplNN : public mojom::Execution {
std::vector<std::unique_ptr<OperandInfo>> inputs_info_;
std::vector<std::unique_ptr<OperandInfo>> outputs_info_;
mojo::ScopedSharedBufferHandle memory_;

const CompilationImplNN* compilation_impl_;
#if defined(OS_LINUX) || defined(OS_WIN)
ie_compilation_t* ie_compilation_;
ie_execution_t* ie_execution_;
#else
ANeuralNetworksCompilation* nn_compilation_;
#endif

DISALLOW_COPY_AND_ASSIGN(ExecutionImplNN);
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class NeuralNetworkContext final : public ScriptWrappable,
static const unsigned long kTensorFloat32 = 3;
static const unsigned long kTensorInt32 = 4;
static const unsigned long kTensorQuant8Asymm = 5;
static const unsigned long kBool = 6;
static const unsigned long kTensorQuant8SymmPerChannel = 11;
static const unsigned long kTensorQuant8AsymmSigned = 14;

Expand Down
14 changes: 14 additions & 0 deletions third_party/blink/renderer/modules/ml/v2/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,24 @@ blink_modules_sources("v2") {
"ops/binary.h",
"ops/constant.cc",
"ops/constant.h",
"ops/conv.cc",
"ops/conv.h",
"ops/input.cc",
"ops/input.h",
"ops/matmul.cc",
"ops/matmul.h",
"ops/output.cc",
"ops/output.h",
"ops/pooling.cc",
"ops/pooling.h",
"ops/relu.cc",
"ops/relu.h",
"ops/reshape.cc",
"ops/reshape.h",
"ops/softmax.cc",
"ops/softmax.h",
"ops/transpose.cc",
"ops/transpose.h",
]

public_deps = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ void NNCompilation::OnCreateExecution(

if (result_code == ml::mojom::blink::NOT_ERROR) {
resolver->Resolve(MakeGarbageCollected<NNExecution>(
std::move(init_params), std::move(name_index_)));
std::move(init_params), name_index_));
} else {
resolver->Reject(MakeGarbageCollected<DOMException>(
DOMExceptionCode::kInvalidStateError,
Expand Down
102 changes: 102 additions & 0 deletions third_party/blink/renderer/modules/ml/v2/nn_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,26 @@
#include "third_party/blink/renderer/modules/ml/v2/nn_model.h"
#include "third_party/blink/renderer/modules/ml/v2/ops/binary.h"
#include "third_party/blink/renderer/modules/ml/v2/ops/constant.h"
#include "third_party/blink/renderer/modules/ml/v2/ops/conv.h"
#include "third_party/blink/renderer/modules/ml/v2/ops/input.h"
#include "third_party/blink/renderer/modules/ml/v2/ops/matmul.h"
#include "third_party/blink/renderer/modules/ml/v2/ops/pooling.h"
#include "third_party/blink/renderer/modules/ml/v2/ops/relu.h"
#include "third_party/blink/renderer/modules/ml/v2/ops/reshape.h"
#include "third_party/blink/renderer/modules/ml/v2/ops/softmax.h"
#include "third_party/blink/renderer/modules/ml/v2/ops/transpose.h"
#include "third_party/blink/renderer/platform/bindings/exception_code.h"

namespace blink {

int32_t product(const WTF::Vector<int32_t>& dims) {
uint32_t prod = 1;
for (auto dim : dims)
prod *= dim;

return prod;
}

namespace {

bool InvalidData(ExceptionState& state) {
Expand Down Expand Up @@ -104,6 +119,26 @@ bool InvalidOperandValue(const OperandDescriptor* descriptor,
return invalid ? InvalidData(exception_state) : false;
}

bool InvalidStrides(WTF::Vector<int32_t>& padding,
WTF::Vector<int32_t>& strides,
WTF::Vector<int32_t>& dilations,
ExceptionState& state) {
if (padding.IsEmpty()) {
padding = Vector<int32_t>(4, 0);
}
if (strides.IsEmpty()) {
strides = Vector<int32_t>(2, 1);
}
if (dilations.IsEmpty()) {
dilations = Vector<int32_t>(2, 1);
}
if (product(padding) < 0 || product(strides) <= 0 ||
product(dilations) <= 0) {
return InvalidData(state);
}
return false;
}

} // namespace

NNContext::NNContext(NavigatorML* navigator_ml)
Expand Down Expand Up @@ -140,6 +175,73 @@ Operand* NNContext::mul(Operand* primary, Operand* secondary) {
return MakeGarbageCollected<Binary>(kBinaryTypeMul, primary, secondary);
}

Operand* NNContext::conv2d(Operand* input,
Operand* filter,
WTF::Vector<int32_t> padding,
WTF::Vector<int32_t> strides,
WTF::Vector<int32_t> dilations,
int32_t groups,
String layout,
ExceptionState& state) {
if (InvalidStrides(padding, strides, dilations, state))
return nullptr;
return MakeGarbageCollected<Conv>(input, filter, std::move(padding),
std::move(strides), std::move(dilations),
groups, layout);
}

Operand* NNContext::averagePool2d(Operand* input,
WTF::Vector<int32_t> window_dimensions,
WTF::Vector<int32_t> padding,
WTF::Vector<int32_t> strides,
WTF::Vector<int32_t> dilations,
String layout,
ExceptionState& state) {
if (InvalidStrides(padding, strides, dilations, state))
return nullptr;
if (window_dimensions.IsEmpty())
window_dimensions = WTF::Vector<int32_t>(2, 0);
return MakeGarbageCollected<Pooling>(
input, std::move(window_dimensions), std::move(padding),
std::move(strides), std::move(dilations), layout, kPoolingTypeAverage);
}

Operand* NNContext::maxPool2d(Operand* input,
WTF::Vector<int32_t> window_dimensions,
WTF::Vector<int32_t> padding,
WTF::Vector<int32_t> strides,
WTF::Vector<int32_t> dilations,
String layout,
ExceptionState& state) {
if (InvalidStrides(padding, strides, dilations, state))
return nullptr;
if (window_dimensions.IsEmpty())
window_dimensions = WTF::Vector<int32_t>(2, 0);
return MakeGarbageCollected<Pooling>(
input, std::move(window_dimensions), std::move(padding),
std::move(strides), std::move(dilations), layout, kPoolingTypeMax);
}

Operand* NNContext::reshape(Operand* input, WTF::Vector<int32_t> new_shape) {
return MakeGarbageCollected<Reshape>(input, std::move(new_shape));
}

Operand* NNContext::softmax(Operand* input) {
return MakeGarbageCollected<Softmax>(input);
}

Operand* NNContext::relu(Operand* input) {
return MakeGarbageCollected<Relu>(input);
}

Operand* NNContext::matmul(Operand* a, Operand* b) {
return MakeGarbageCollected<MatMul>(a, b);
}

Operand* NNContext::transpose(Operand* input, WTF::Vector<int32_t> new_shape) {
return MakeGarbageCollected<Transpose>(input, std::move(new_shape));
}

ScriptPromise NNContext::createModel(ScriptState* script_state,
const NamedOperandVector& outputs) {
auto* resolver = MakeGarbageCollected<ScriptPromiseResolver>(script_state);
Expand Down
29 changes: 29 additions & 0 deletions third_party/blink/renderer/modules/ml/v2/nn_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ class NavigatorML;

using NamedOperandVector = HeapVector<Member<NamedOperand>>;

int32_t product(const WTF::Vector<int32_t>&);

class NNContext final : public ScriptWrappable,
public ExecutionContextLifecycleObserver {
DEFINE_WRAPPERTYPEINFO();
Expand All @@ -42,7 +44,34 @@ class NNContext final : public ScriptWrappable,
ExceptionState&);
Operand* add(Operand*, Operand*);
Operand* mul(Operand*, Operand*);
Operand* conv2d(Operand*,
Operand*,
WTF::Vector<int32_t>,
WTF::Vector<int32_t>,
WTF::Vector<int32_t>,
int32_t,
String,
ExceptionState&);
Operand* averagePool2d(Operand*,
WTF::Vector<int32_t>,
WTF::Vector<int32_t>,
WTF::Vector<int32_t>,
WTF::Vector<int32_t>,
String,
ExceptionState&);
Operand* maxPool2d(Operand*,
WTF::Vector<int32_t>,
WTF::Vector<int32_t>,
WTF::Vector<int32_t>,
WTF::Vector<int32_t>,
String,
ExceptionState&);
Operand* reshape(Operand*, WTF::Vector<int32_t>);
Operand* softmax(Operand*);
Operand* relu(Operand*);
Operand* matmul(Operand*, Operand*);
ScriptPromise createModel(ScriptState*, const NamedOperandVector&);
Operand* transpose(Operand*, WTF::Vector<int32_t>);

// ExecutionContextLifecycleObserver overrides.
void ContextDestroyed() override;
Expand Down
21 changes: 20 additions & 1 deletion third_party/blink/renderer/modules/ml/v2/nn_context.idl
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,29 @@ interface NNContext {
[RaisesException] Operand input(DOMString name, OperandDescriptor desc);
[RaisesException] Operand constant(OperandDescriptor desc, [AllowShared] ArrayBufferView data);

// element-wise binary operatioins.
// element-wise binary operations.
Operand add(Operand primary, Operand secondary);
Operand mul(Operand primary, Operand secondary);

[RaisesException] Operand conv2d(Operand input, Operand filter,
optional sequence<long> padding = [], optional sequence<long> strides = [],
optional sequence<long> dilations = [], optional long groups = 1,
optional OperandLayout layout = "nchw");

// Pooling operations.
[RaisesException] Operand averagePool2d(Operand input, optional sequence<long> windowDimensions = [],
optional sequence<long> padding = [], optional sequence<long> strides = [],
optional sequence<long> dilations = [], optional OperandLayout layout = "nchw");
[RaisesException] Operand maxPool2d(Operand input, optional sequence<long> windowDimensions = [],
optional sequence<long> padding = [], optional sequence<long> strides = [],
optional sequence<long> dilations = [], optional OperandLayout layout = "nchw");

Operand reshape(Operand input, sequence<long> newShape);
Operand softmax(Operand input);
Operand relu(Operand input);
Operand matmul(Operand a, Operand b);
Operand transpose(Operand input, optional sequence<long> permutation=[]);

// Create Model
[CallWith=ScriptState] Promise<Model> createModel(sequence<NamedOperand> outputs);
};
Loading