Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add types.Unalias to types assertions and types switches to get an underlying type instead of types.Alias #33868

Merged
merged 11 commits into from
Feb 19, 2025
Merged
44 changes: 30 additions & 14 deletions sdks/go/pkg/beam/util/starcgenx/starcgenx.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

// Package starcgenx is a Static Analysis Type Assertion shim and Registration Code Generator
// which provides an extractor to extract types from a package, in order to generate
// approprate shimsr a package so code can be generated for it.
// appropriate shims for a package so code can be generated for it.
//
// It's written for use by the starcgen tool, but separate to permit
// alternative "go/importer" Importers for accessing types from imported packages.
Expand Down Expand Up @@ -336,6 +336,7 @@ func (e *Extractor) isRequired(ident string, obj types.Object, idsRequired, idsF
// or it's receiver type identifier needs to be in the filtered identifiers.
if idsRequired[ident] {
idsFound[ident] = true
e.Printf("isRequired found: %s\n", ident)
return true
}
// Check if this is a function.
Expand All @@ -347,10 +348,10 @@ func (e *Extractor) isRequired(ident string, obj types.Object, idsRequired, idsF
if recv := sig.Recv(); recv != nil && graph.IsLifecycleMethod(ident) {
// We don't want to care about pointers, so dereference to value type.
t := recv.Type()
p, ok := t.(*types.Pointer)
p, ok := types.Unalias(t).(*types.Pointer)
for ok {
t = p.Elem()
p, ok = t.(*types.Pointer)
p, ok = types.Unalias(t).(*types.Pointer)
}
ts := types.TypeString(t, e.qualifier)
e.Printf("recv %v has %v, ts: %s %s--- ", recv, sig, ts, ident)
Expand Down Expand Up @@ -384,14 +385,16 @@ func (e *Extractor) fromObj(fset *token.FileSet, id *ast.Ident, obj types.Object
ident = obj.Name()
}
if !e.isRequired(ident, obj, idsRequired, idsFound) {
e.Printf("%s: %q with package %q is not required \n",
fset.Position(id.Pos()), id.Name, pkg.Name())
return
}

switch ot := obj.(type) {
case *types.Var:
// Vars are tricky since they could be anything, and anywhere (package scope, parameters, etc)
// eg. Flags, or Field Tags, among others.
// I'm increasingly convinced that we should simply igonore vars.
// I'm increasingly convinced that we should simply ignore vars.
// Do nothing for vars.
case *types.Func:
sig := obj.Type().(*types.Signature)
Expand All @@ -405,10 +408,10 @@ func (e *Extractor) fromObj(fset *token.FileSet, id *ast.Ident, obj types.Object
}
// This must be a structural DoFn! We should generate a closure wrapper for it.
t := recv.Type()
p, ok := t.(*types.Pointer)
p, ok := types.Unalias(t).(*types.Pointer)
for ok {
t = p.Elem()
p, ok = t.(*types.Pointer)
p, ok = types.Unalias(t).(*types.Pointer)
}
ts := types.TypeString(t, e.qualifier)
mthdMap := e.wraps[ts]
Expand Down Expand Up @@ -453,6 +456,10 @@ func (e *Extractor) extractType(ot *types.TypeName) {
// A single level is safe since the code we're analysing imports it,
// so we can assume the generated code can access it too.
if ot.IsAlias() {
if t, ok := ot.Type().(*types.Alias); ok {
ot = t.Obj()
name = types.TypeString(t, e.qualifier)
}
if t, ok := ot.Type().(*types.Named); ok {
ot = t.Obj()
name = types.TypeString(t, e.qualifier)
Expand All @@ -461,7 +468,7 @@ func (e *Extractor) extractType(ot *types.TypeName) {
// Only register non-universe types (eg. avoid `error` and similar)
if pkg := ot.Pkg(); pkg != nil {
path := pkg.Path()
e.imports[pkg.Path()] = struct{}{}
e.imports[path] = struct{}{}

// Do not add universal types to be registered.
if path == shimx.TypexImport {
Expand All @@ -484,17 +491,17 @@ func (e *Extractor) extractFromContainer(t types.Type) types.Type {
// Container types need to be iteratively unwrapped until we're at the base type,
// so we can get the import if necessary.
for {
if s, ok := t.(*types.Slice); ok {
if s, ok := types.Unalias(t).(*types.Slice); ok {
t = s.Elem()
continue
}

if p, ok := t.(*types.Pointer); ok {
if p, ok := types.Unalias(t).(*types.Pointer); ok {
t = p.Elem()
continue
}

if a, ok := t.(*types.Array); ok {
if a, ok := types.Unalias(t).(*types.Array); ok {
t = a.Elem()
continue
}
Expand All @@ -510,9 +517,18 @@ func (e *Extractor) extractFromTuple(tuple *types.Tuple) {
t := e.extractFromContainer(s.Type())

// Here's where we ensure we register new imports.
if at, ok := t.(*types.Alias); ok {
e.Printf("extractFromTuple: %v is an alias - RHS %T\n", at, at.Rhs())
if pkg := at.Obj().Pkg(); pkg != nil {
e.imports[pkg.Path()] = struct{}{}
}
}
if t, ok := t.(*types.Named); ok {
if pkg := t.Obj().Pkg(); pkg != nil {
e.Printf("extractType: adding import path %q for %v\n", pkg.Path(), t)
if pkg := t.Obj().Pkg(); pkg != nil {
e.imports[pkg.Path()] = struct{}{}
} else {
e.Printf("extractType: %v has no package to import\n", t)
}
e.extractType(t.Obj())
}
Expand Down Expand Up @@ -683,7 +699,7 @@ func (e *Extractor) makeEmitter(sig *types.Signature) (shimx.Emitter, bool) {

// makeInput checks if the given signature is an iterator or not, and if so,
// returns a shimx.Input struct for the signature for use by the code
// generator. The canonical check for an iterater signature is in the
// generator. The canonical check for an iterator signature is in the
// funcx.UnfoldIter function which uses the reflect library,
// and this logic is replicated here.
func (e *Extractor) makeInput(sig *types.Signature) (shimx.Input, bool) {
Expand All @@ -692,13 +708,13 @@ func (e *Extractor) makeInput(sig *types.Signature) (shimx.Input, bool) {
return shimx.Input{}, false
}
// Iterators must return a bool.
if b, ok := r.At(0).Type().(*types.Basic); !ok || b.Kind() != types.Bool {
if b, ok := types.Unalias(r.At(0).Type()).(*types.Basic); !ok || b.Kind() != types.Bool {
return shimx.Input{}, false
}
p := sig.Params()
for i := 0; i < p.Len(); i++ {
// All params for iterators must be pointers.
if _, ok := p.At(i).Type().(*types.Pointer); !ok {
if _, ok := types.Unalias(p.At(i).Type()).(*types.Pointer); !ok {
return shimx.Input{}, false
}
}
Expand Down
Loading