diff --git a/cache.go b/cache.go index 07a1829..553e91a 100644 --- a/cache.go +++ b/cache.go @@ -3,7 +3,7 @@ package luar import ( "reflect" - "github.com/yuin/gopher-lua" + lua "github.com/yuin/gopher-lua" ) func addMethods(L *lua.LState, c *Config, vtype reflect.Type, tbl *lua.LTable, ptrReceiver bool) { @@ -20,6 +20,15 @@ func addMethods(L *lua.LState, c *Config, vtype reflect.Type, tbl *lua.LTable, p for _, name := range namesFn(vtype, method) { tbl.RawSetString(name, fn) } + + if c.PreprocessMetatables { + for j := 0; j < method.Type.NumIn(); j++ { + preprocessMetatables(L, c, method.Type.In(j)) + } + for j := 0; j < method.Type.NumOut(); j++ { + preprocessMetatables(L, c, method.Type.Out(j)) + } + } } } @@ -86,6 +95,7 @@ func addFields(L *lua.LState, c *Config, vtype reflect.Type, tbl *lua.LTable) { tbl.RawSetString(alias, ud) } } + preprocessMetatables(L, c, field.Type) } } @@ -95,6 +105,10 @@ func getMetatable(L *lua.LState, vtype reflect.Type) *lua.LTable { if v := config.regular[vtype]; v != nil { return v } + if config.processing[vtype] { + return nil + } + config.processing[vtype] = true var ( mt *lua.LTable @@ -184,7 +198,14 @@ func getMetatable(L *lua.LState, vtype reflect.Type) *lua.LTable { mt.RawSetString("__metatable", lua.LString("gopher-luar")) mt.RawSetString("methods", methods) + if process := config.Metatable; process != nil { + if newmt := process(L, vtype, mt, false); newmt != nil { + mt = newmt + } + } + config.regular[vtype] = mt + delete(config.processing, vtype) return mt } @@ -200,6 +221,39 @@ func getTypeMetatable(L *lua.LState, t reflect.Type) *lua.LTable { mt.RawSetString("__eq", L.NewFunction(typeEq)) mt.RawSetString("__metatable", lua.LString("gopher-luar")) + if process := config.Metatable; process != nil { + if newmt := process(L, t, mt, true); newmt != nil { + mt = newmt + } + } + config.types = mt return mt } + +func preprocessMetatables(L *lua.LState, c *Config, t reflect.Type) { + if !c.PreprocessMetatables || !doesGetMetatableHandle(t) { + return + } + getMetatable(L, t) + // also process the underlying type for containers + if k := t.Kind(); k == reflect.Ptr || k == reflect.Slice || + k == reflect.Array || k == reflect.Chan { + if doesGetMetatableHandle(t.Elem()) { + preprocessMetatables(L, c, t.Elem()) + } + } else if k == reflect.Map { + if doesGetMetatableHandle(t.Key()) { + preprocessMetatables(L, c, t.Key()) + } + if doesGetMetatableHandle(t.Elem()) { + preprocessMetatables(L, c, t.Elem()) + } + } +} + +func doesGetMetatableHandle(t reflect.Type) bool { + k := t.Kind() + return k == reflect.Struct || k == reflect.Ptr || k == reflect.Slice || + k == reflect.Array || k == reflect.Chan || k == reflect.Map +} diff --git a/config.go b/config.go index b3f4736..40f5461 100644 --- a/config.go +++ b/config.go @@ -3,7 +3,7 @@ package luar import ( "reflect" - "github.com/yuin/gopher-lua" + lua "github.com/yuin/gopher-lua" ) // Config is used to define luar behaviour for a particular *lua.LState. @@ -26,13 +26,46 @@ type Config struct { // - the method name and its name with a lowercase first letter MethodNames func(t reflect.Type, m reflect.Method) []string - regular map[reflect.Type]*lua.LTable - types *lua.LTable + // The metatable post-processor function. This gets run last, after the + // [default implementation] is done. If nil, skipped. You may use this to + // provide custom metamethods or the like for specific Go types. + // + // If the constructor parameter is true, the metatable is being created for + // the global constructor metatable; as such, this will only be called once. + // + // [default implementation]: https://github.com/layeh/gopher-luar/blob/master/cache.go#L92 + Metatable func(L *lua.LState, t reflect.Type, mt *lua.LTable, constructor bool) *lua.LTable + + // When true, all metatables are fully processed as soon as they are + // discovered. This increases memory use but enables operations that depend + // on complete metatable information (e.g., full annotation generation via + // the use of [Config.Metatable]). + // + // When false (default), metatables are processed only when needed. Calling + // [New] triggers initial processing, and further processing occurs only + // when a field, method, or call is accessed. If true, all referenced types + // are processed during New rather than on first use. + // + // This is not affected by [NewType] due to how luar is implemented. If you + // wish to do things based on the type of a [NewType], simply call [New] and + // unless you need it, discard the result. + PreprocessMetatables bool + + // overhead: to prevent recursion for [Config.PreprocessMetatables], we must + // store a map of types that are currently being processed. I wanted to use + // [Config.regular] with an empty table, but [getMetatable] puts the exact + // number of entries as the capacity, and unless I make a second switch, + // I can't think of a good way to emulate this behaviour. This will work for + // now, I think. FIXME. + processing map[reflect.Type]bool + regular map[reflect.Type]*lua.LTable + types *lua.LTable } func newConfig() *Config { return &Config{ - regular: make(map[reflect.Type]*lua.LTable), + processing: make(map[reflect.Type]bool), + regular: make(map[reflect.Type]*lua.LTable), } } diff --git a/preprocess_test.go b/preprocess_test.go new file mode 100644 index 0000000..531386b --- /dev/null +++ b/preprocess_test.go @@ -0,0 +1,110 @@ +package luar + +import ( + "reflect" + "testing" + + lua "github.com/yuin/gopher-lua" +) + +type Foo struct{ A Bar } // test field + +type Bar struct{ A int } // no-op + +func (b Bar) Baz() Baz { return Baz{A: struct{ A FooFoo }{}} } // test method + +type Baz struct{ A struct{ A FooFoo } } // test nested + +type FooFoo struct { + a FooBar + b FooBaz + + T1 int + T2 string + T3 map[string]int + T4 []int +} // b is hidden as field, visible as method + +func (f FooFoo) Baz(bar BarBaz) FooBaz { return f.b } + +type FooBar struct{ A int } + +type FooBaz struct{ A [][][]*[][]*[]*[][]*BarFoo } // ungodly depth achieved! + +type BarFoo struct{ A BarBar } + +type BarBar int + +type BarBaz interface{ Baz() FooBaz } + +func Test_preprocess(t *testing.T) { + t.Run("enabled", testPreprocess([]reflect.Type{ + reflect.TypeOf(Foo{}), + reflect.TypeOf(Bar{}), + reflect.TypeOf(Baz{}), + reflect.TypeOf(struct{ A FooFoo }{}), // nested + reflect.TypeOf(FooFoo{}), + // reflect.TypeOf(FooBar{}), // hidden, skipped + reflect.TypeOf(FooBaz{}), // hidden as field, visible as method + // reflect.TypeOf(BarBar(0)), // not of interest + reflect.TypeOf([]int{}), + reflect.TypeOf(map[string]int{}), + + // oh no... + reflect.TypeOf([][][]*[][]*[]*[][]*BarFoo{}), + reflect.TypeOf([][]*[][]*[]*[][]*BarFoo{}), + reflect.TypeOf([]*[][]*[]*[][]*BarFoo{}), + reflect.TypeOf(&[][]*[]*[][]*BarFoo{}), + reflect.TypeOf([][]*[]*[][]*BarFoo{}), + reflect.TypeOf([]*[]*[][]*BarFoo{}), + reflect.TypeOf(&[]*[][]*BarFoo{}), + reflect.TypeOf([]*[][]*BarFoo{}), + reflect.TypeOf(&[][]*BarFoo{}), + reflect.TypeOf([][]*BarFoo{}), + reflect.TypeOf([]*BarFoo{}), + reflect.TypeOf(&BarFoo{}), + reflect.TypeOf(BarFoo{}), + }, true, New)) + t.Run("disabled", testPreprocess([]reflect.Type{ + reflect.TypeOf(Foo{}), // no preprocessing -> only the original type + }, false, New)) + + t.Run("type_enabled", testPreprocess([]reflect.Type{reflect.TypeOf(Foo{})}, true, NewType)) + t.Run("type_disabled", testPreprocess([]reflect.Type{reflect.TypeOf(Foo{})}, false, NewType)) +} + +type newfn = func(L *lua.LState, value interface{}) lua.LValue + +func testPreprocess(expected []reflect.Type, preprocess bool, newfn newfn) func(t *testing.T) { + return func(t *testing.T) { + L := lua.NewState() + defer L.Close() + + got := map[string]bool{} + + GetConfig(L).PreprocessMetatables = preprocess + GetConfig(L).Metatable = func(L *lua.LState, t reflect.Type, mt *lua.LTable, constructor bool) *lua.LTable { + got[t.String()] = true + return mt + } + + newfn(L, Foo{}) + + expectedM := map[string]bool{} + for _, v := range expected { + expectedM[v.String()] = true + } + for name := range expectedM { + if _, ok := got[name]; ok { + t.Logf("processed %s", name) + } else { + t.Errorf("expected %s to be processed", name) + } + } + for name := range got { + if _, ok := expectedM[name]; !ok { + t.Errorf("did not expect %s to be processed", name) + } + } + } +}