diff --git a/sdks/go/pkg/beam/util/starcgenx/starcgenx.go b/sdks/go/pkg/beam/util/starcgenx/starcgenx.go index b5cd0ddc8eb8..22d2be6e43f1 100644 --- a/sdks/go/pkg/beam/util/starcgenx/starcgenx.go +++ b/sdks/go/pkg/beam/util/starcgenx/starcgenx.go @@ -15,7 +15,7 @@ // Package starcgenx is a Static Analysis Type Assertion shim and Registration Code Generator // which provides an extractor to extract types from a package, in order to generate -// approprate shimsr a package so code can be generated for it. +// appropriate shims for a package so code can be generated for it. // // It's written for use by the starcgen tool, but separate to permit // alternative "go/importer" Importers for accessing types from imported packages. @@ -336,6 +336,7 @@ func (e *Extractor) isRequired(ident string, obj types.Object, idsRequired, idsF // or it's receiver type identifier needs to be in the filtered identifiers. if idsRequired[ident] { idsFound[ident] = true + e.Printf("isRequired found: %s\n", ident) return true } // Check if this is a function. @@ -347,10 +348,10 @@ func (e *Extractor) isRequired(ident string, obj types.Object, idsRequired, idsF if recv := sig.Recv(); recv != nil && graph.IsLifecycleMethod(ident) { // We don't want to care about pointers, so dereference to value type. t := recv.Type() - p, ok := t.(*types.Pointer) + p, ok := types.Unalias(t).(*types.Pointer) for ok { t = p.Elem() - p, ok = t.(*types.Pointer) + p, ok = types.Unalias(t).(*types.Pointer) } ts := types.TypeString(t, e.qualifier) e.Printf("recv %v has %v, ts: %s %s--- ", recv, sig, ts, ident) @@ -384,6 +385,8 @@ func (e *Extractor) fromObj(fset *token.FileSet, id *ast.Ident, obj types.Object ident = obj.Name() } if !e.isRequired(ident, obj, idsRequired, idsFound) { + e.Printf("%s: %q with package %q is not required \n", + fset.Position(id.Pos()), id.Name, pkg.Name()) return } @@ -391,7 +394,7 @@ func (e *Extractor) fromObj(fset *token.FileSet, id *ast.Ident, obj types.Object case *types.Var: // Vars are tricky since they could be anything, and anywhere (package scope, parameters, etc) // eg. Flags, or Field Tags, among others. - // I'm increasingly convinced that we should simply igonore vars. + // I'm increasingly convinced that we should simply ignore vars. // Do nothing for vars. case *types.Func: sig := obj.Type().(*types.Signature) @@ -405,10 +408,10 @@ func (e *Extractor) fromObj(fset *token.FileSet, id *ast.Ident, obj types.Object } // This must be a structural DoFn! We should generate a closure wrapper for it. t := recv.Type() - p, ok := t.(*types.Pointer) + p, ok := types.Unalias(t).(*types.Pointer) for ok { t = p.Elem() - p, ok = t.(*types.Pointer) + p, ok = types.Unalias(t).(*types.Pointer) } ts := types.TypeString(t, e.qualifier) mthdMap := e.wraps[ts] @@ -453,6 +456,10 @@ func (e *Extractor) extractType(ot *types.TypeName) { // A single level is safe since the code we're analysing imports it, // so we can assume the generated code can access it too. if ot.IsAlias() { + if t, ok := ot.Type().(*types.Alias); ok { + ot = t.Obj() + name = types.TypeString(t, e.qualifier) + } if t, ok := ot.Type().(*types.Named); ok { ot = t.Obj() name = types.TypeString(t, e.qualifier) @@ -461,7 +468,7 @@ func (e *Extractor) extractType(ot *types.TypeName) { // Only register non-universe types (eg. avoid `error` and similar) if pkg := ot.Pkg(); pkg != nil { path := pkg.Path() - e.imports[pkg.Path()] = struct{}{} + e.imports[path] = struct{}{} // Do not add universal types to be registered. if path == shimx.TypexImport { @@ -484,17 +491,17 @@ func (e *Extractor) extractFromContainer(t types.Type) types.Type { // Container types need to be iteratively unwrapped until we're at the base type, // so we can get the import if necessary. for { - if s, ok := t.(*types.Slice); ok { + if s, ok := types.Unalias(t).(*types.Slice); ok { t = s.Elem() continue } - if p, ok := t.(*types.Pointer); ok { + if p, ok := types.Unalias(t).(*types.Pointer); ok { t = p.Elem() continue } - if a, ok := t.(*types.Array); ok { + if a, ok := types.Unalias(t).(*types.Array); ok { t = a.Elem() continue } @@ -510,9 +517,17 @@ func (e *Extractor) extractFromTuple(tuple *types.Tuple) { t := e.extractFromContainer(s.Type()) // Here's where we ensure we register new imports. + if at, ok := t.(*types.Alias); ok { + if pkg := at.Obj().Pkg(); pkg != nil { + e.imports[pkg.Path()] = struct{}{} + } + } if t, ok := t.(*types.Named); ok { if pkg := t.Obj().Pkg(); pkg != nil { + e.Printf("extractType: adding import path %q for %v\n", pkg.Path(), t) e.imports[pkg.Path()] = struct{}{} + } else { + e.Printf("extractType: %v has no package to import\n", t) } e.extractType(t.Obj()) } @@ -683,7 +698,7 @@ func (e *Extractor) makeEmitter(sig *types.Signature) (shimx.Emitter, bool) { // makeInput checks if the given signature is an iterator or not, and if so, // returns a shimx.Input struct for the signature for use by the code -// generator. The canonical check for an iterater signature is in the +// generator. The canonical check for an iterator signature is in the // funcx.UnfoldIter function which uses the reflect library, // and this logic is replicated here. func (e *Extractor) makeInput(sig *types.Signature) (shimx.Input, bool) { @@ -692,13 +707,13 @@ func (e *Extractor) makeInput(sig *types.Signature) (shimx.Input, bool) { return shimx.Input{}, false } // Iterators must return a bool. - if b, ok := r.At(0).Type().(*types.Basic); !ok || b.Kind() != types.Bool { + if b, ok := types.Unalias(r.At(0).Type()).(*types.Basic); !ok || b.Kind() != types.Bool { return shimx.Input{}, false } p := sig.Params() for i := 0; i < p.Len(); i++ { // All params for iterators must be pointers. - if _, ok := p.At(i).Type().(*types.Pointer); !ok { + if _, ok := types.Unalias(p.At(i).Type()).(*types.Pointer); !ok { return shimx.Input{}, false } }