diff --git a/src/program/BlockStatement.js b/src/program/BlockStatement.js index 43c4fba0..05ced59c 100644 --- a/src/program/BlockStatement.js +++ b/src/program/BlockStatement.js @@ -176,6 +176,53 @@ export default class BlockStatement extends Node { super.transpile(code, transforms); + if (transforms.asyncAwait && this.isFunctionBlock && this.parent.async && this.body.length) { + const first = this.body[0]; + const last = this.body[this.body.length - 1]; + const hasOnlyOneLine = this.body.length === 1; + + // TODO refactor :) + if (this.parent.type === 'FunctionDeclaration') { + if (hasOnlyOneLine) { + if (first.type === 'ReturnStatement') { + code.insertLeft(first.argument.start, 'Promise.resolve().then(function() { '); + code.insertLeft(first.end, ' })'); + } else { + code.insertLeft(first.start, 'return Promise.resolve().then(function() { '); + code.insertRight(last.end, ' }).then(function() {})'); + } + } else { + code.insertLeft(first.start, 'return Promise.resolve()'); + code.insertRight(last.end, '.then(function() {})'); + + for (let i = 0; i < this.body.length; i++) { + const prev = this.body[i - 1]; + const cur = this.body[i]; + const next = this.body[i + 1]; + + if (cur.expression.type === 'AwaitExpression') { + code.insertLeft(cur.start, '.then(function() { '); + code.insertRight(cur.end, ' })'); + } else { + if (!prev || prev.expression.type === 'AwaitExpression') { + code.insertLeft(cur.start, '.then(function() { '); + } + + if (!next || next.expression.type === 'AwaitExpression') { + code.insertRight(cur.end, ' })'); + } + } + } + } + + } else if (this.parent.type === 'ArrowFunctionExpression') { + // TODO merge with ^ + // wrap the function's body in a promise + code.insertLeft(first.start + 1, 'Promise.resolve().then(function() { return '); + code.insertLeft(last.end, ' })'); + } + } + if (this.createdDeclarations.length) { introStatementGenerators.push((start, prefix, suffix) => { const assignment = `${prefix}var ${this.createdDeclarations.join(', ')}${suffix}`; diff --git a/src/program/types/ArrowFunctionExpression.js b/src/program/types/ArrowFunctionExpression.js index 97e5e9ba..eecc6507 100644 --- a/src/program/types/ArrowFunctionExpression.js +++ b/src/program/types/ArrowFunctionExpression.js @@ -1,6 +1,7 @@ import Node from '../Node.js'; import CompileError from '../../utils/CompileError.js'; import removeTrailingComma from '../../utils/removeTrailingComma.js'; +import AwaitExpression from './AwaitExpression'; export default class ArrowFunctionExpression extends Node { initialise(transforms) { @@ -22,6 +23,7 @@ export default class ArrowFunctionExpression extends Node { } code.remove(charIndex, this.body.start); + AwaitExpression.removeAsync(code, transforms, this.async, this.start); super.transpile(code, transforms); // wrap naked parameter @@ -38,6 +40,7 @@ export default class ArrowFunctionExpression extends Node { code.prependRight(this.start, 'function '); } } else { + AwaitExpression.removeAsync(code, transforms, this.async, this.start); super.transpile(code, transforms); } diff --git a/src/program/types/AwaitExpression.js b/src/program/types/AwaitExpression.js index 8028df5a..b618e5d9 100644 --- a/src/program/types/AwaitExpression.js +++ b/src/program/types/AwaitExpression.js @@ -1,11 +1,20 @@ import Node from '../Node.js'; -import CompileError from '../../utils/CompileError.js'; + +const noop = () => {}; export default class AwaitExpression extends Node { - initialise(transforms) { - if (transforms.asyncAwait) { - CompileError.missingTransform("await", "asyncAwait", this); + static removeAsync(code, transforms, async, start, callback = noop) { + if (transforms.asyncAwait && async) { + code.remove(start, start + 6); + callback(); } - super.initialise(transforms); + } + + transpile (code, transforms) { + AwaitExpression.removeAsync(code, transforms, true, this.start, () => { + code.insertLeft(this.argument.start, 'return '); + }); + + super.transpile(code, transforms); } } diff --git a/src/program/types/FunctionDeclaration.js b/src/program/types/FunctionDeclaration.js index 03db09ae..9c2f58e1 100644 --- a/src/program/types/FunctionDeclaration.js +++ b/src/program/types/FunctionDeclaration.js @@ -1,6 +1,7 @@ import Node from '../Node.js'; import CompileError from '../../utils/CompileError.js'; import removeTrailingComma from '../../utils/removeTrailingComma.js'; +import AwaitExpression from './AwaitExpression'; export default class FunctionDeclaration extends Node { initialise(transforms) { @@ -20,6 +21,7 @@ export default class FunctionDeclaration extends Node { } transpile(code, transforms) { + AwaitExpression.removeAsync(code, transforms, this.async, this.start); super.transpile(code, transforms); if (transforms.trailingFunctionCommas && this.params.length) { removeTrailingComma(code, this.params[this.params.length - 1].end); diff --git a/test/samples/async-await.js b/test/samples/async-await.js new file mode 100644 index 00000000..49a939f8 --- /dev/null +++ b/test/samples/async-await.js @@ -0,0 +1,25 @@ +module.exports = [ + { + description: 'transpiles await arrow function call', + input: `async () => await a()`, + output: `!function() { return Promise.resolve().then(function() { return a() }); }` + }, + + { + description: 'transpiles await function call', + input: `async function f() { await a(); }`, + output: `function f() { return Promise.resolve().then(function() { return a(); }).then(function() {}) }` + }, + + { + description: 'transpiles await function call with return statement', + input: `async function f() { return await a(); }`, + output: `function f() { return Promise.resolve().then(function() { return a(); }) }` + }, + + { + description: 'transpiles await function call with more than one line of code', + input: `async function f() { await a(); thing(); await a2(); stuff(); await a3(); await a4(); }`, + output: `function f() { return Promise.resolve().then(function() { return a(); }) .then(function() { thing(); }) .then(function() { return a2(); }) .then(function() { stuff(); }) .then(function() { return a3(); }) .then(function() { return a4(); }).then(function() {}) }` + } +];