diff --git a/cmd/gormlint/main.go b/cmd/gormlint/main.go index a78e4e5..e2274b6 100644 --- a/cmd/gormlint/main.go +++ b/cmd/gormlint/main.go @@ -5,6 +5,7 @@ import ( "gormlint/foreignKeyCheck" "gormlint/nullSafetyCheck" "gormlint/referencesCheck" + "gormlint/relationsCheck" ) func main() { @@ -12,5 +13,6 @@ func main() { nullSafetyCheck.NullSafetyAnalyzer, referencesCheck.ReferenceAnalyzer, foreignKeyCheck.ForeignKeyCheck, + relationsCheck.RelationsAnalyzer, ) } diff --git a/common/finders.go b/common/finders.go new file mode 100644 index 0000000..f9fd4ac --- /dev/null +++ b/common/finders.go @@ -0,0 +1,100 @@ +package common + +import ( + "go/ast" + "strings" +) + +func GetModelField(model *Model, fieldName string) *Field { + field, fieldExists := model.Fields[fieldName] + if fieldExists { + return &field + } else { + return nil + } +} + +func GetModelFromType(modelType ast.Expr, models map[string]Model) *Model { + baseType := ResolveBaseType(modelType) + if baseType != nil { + return GetRelatedModel(*baseType, models) + } else { + return nil + } +} + +func GetRelatedModel(modelName string, models map[string]Model) *Model { + model, modelExists := models[modelName] + if modelExists { + return &model + } else { + return nil + } +} + +func FindParamValue(paramName string, params []string) *string { + for _, rawParam := range params { + pair := strings.Split(rawParam, ":") + if len(pair) < 2 { + return nil + } + if strings.ToLower(pair[0]) == strings.ToLower(paramName) { + return &pair[1] + } + } + return nil +} + +func FindModelParam(paramName string, model Model) *Param { + for _, field := range model.Fields { + for _, param := range field.Params { + pair := strings.Split(param, ":") + if len(pair) < 2 { + return nil + } + if strings.ToLower(pair[0]) == strings.ToLower(paramName) { + return &Param{ + Name: pair[0], + Value: pair[1], + } + } + } + } + return nil +} + +func FindReferencesInM2M(m2mReference Field, relatedModel Model) *Field { + /* Find `references` field in m2m relation */ + referencesTagValue := FindParamValue("references", m2mReference.Params) + if referencesTagValue != nil { + return GetModelField(&relatedModel, *referencesTagValue) + } else { + for _, field := range relatedModel.Fields { + for _, opt := range field.Options { + if opt == "primaryKey" { + return &field + } + } + } + for _, field := range relatedModel.Fields { + if strings.ToLower(field.Name) == "id" { + return &field + } + } + return nil + } +} + +func FindBackReferenceInM2M(relationName string, relatedModel Model) *Field { + for _, field := range relatedModel.Fields { + m2mRelation := field.GetParam("many2many") + if m2mRelation != nil { + if m2mRelation.Value == relationName { + return &field + } + } + } + return nil +} + +//func findForeignKey() diff --git a/common/isSlice.go b/common/isSlice.go new file mode 100644 index 0000000..87958c2 --- /dev/null +++ b/common/isSlice.go @@ -0,0 +1,11 @@ +package common + +import "go/ast" + +func IsSlice(expr ast.Expr) bool { + arrayType, ok := expr.(*ast.ArrayType) + if !ok { + return false + } + return arrayType.Len == nil +} diff --git a/common/model.go b/common/model.go index dd4175b..8a67346 100644 --- a/common/model.go +++ b/common/model.go @@ -3,6 +3,7 @@ package common import ( "go/ast" "go/token" + "strings" ) type Field struct { @@ -21,3 +22,50 @@ type Model struct { Pos token.Pos Comment string } + +type Param struct { + Name string + Value string +} + +func (model *Model) GetParam(name string) *Param { + for _, field := range model.Fields { + for _, param := range field.Params { + pair := strings.SplitN(param, ":", 2) + if len(pair) != 2 { + return nil + } + if strings.ToLower(pair[0]) == strings.ToLower(name) { + return &Param{ + Name: pair[0], + Value: pair[1], + } + } + } + } + return nil +} + +func (model *Model) HasParam(name string) bool { + return model.GetParam(name) != nil +} + +func (field *Field) HasParam(name string) bool { + return field.GetParam(name) != nil +} + +func (field *Field) GetParam(name string) *Param { + for _, param := range field.Params { + pair := strings.SplitN(param, ":", 2) + if len(pair) != 2 { + return nil + } + if strings.ToLower(pair[0]) == strings.ToLower(name) { + return &Param{ + Name: pair[0], + Value: pair[1], + } + } + } + return nil +} diff --git a/relationsCheck/relationsCheck.go b/relationsCheck/relationsCheck.go new file mode 100644 index 0000000..cad6afa --- /dev/null +++ b/relationsCheck/relationsCheck.go @@ -0,0 +1,89 @@ +package relationsCheck + +import ( + "fmt" + "golang.org/x/tools/go/analysis" + "gormlint/common" + "slices" +) + +// RelationsAnalyzer todo: add URL for rule +var RelationsAnalyzer = &analysis.Analyzer{ + Name: "GormReferencesCheck", + Doc: "report about invalid references in models", + Run: run, +} + +func CheckTypesOfM2M(pass *analysis.Pass, modelName string, relatedModelName string, relationName string, reference common.Field, backReference common.Field) { + if !common.IsSlice(reference.Type) { + pass.Reportf(reference.Pos, "M2M relation `%s` with bad type `%s` (should be a slice)", relationName, reference.Type) + return + } + if !common.IsSlice(backReference.Type) { + pass.Reportf(backReference.Pos, "M2M relation `%s` with bad type `%s` (should be a slice)", relationName, backReference.Type) + return + } + + referenceBaseType := common.ResolveBaseType(reference.Type) + if referenceBaseType == nil { + pass.Reportf(reference.Pos, "Failed to resolve field type: `%s`", reference.Type) + return + } + backReferenceBaseType := common.ResolveBaseType(backReference.Type) + if backReferenceBaseType == nil { + pass.Reportf(reference.Pos, "Failed to resolve type: `%s`", reference.Type) + return + } + + if *backReferenceBaseType != modelName { + pass.Reportf(backReference.Pos, "Invalid type `%s` in M2M relation (use []*%s or self-reference)", *backReferenceBaseType, modelName) + return + } + + if *referenceBaseType != relatedModelName { + pass.Reportf(reference.Pos, "Invalid type `%s` in M2M relation (use []*%s or self-reference)", *referenceBaseType, relatedModelName) + } +} + +func CheckMany2Many(pass *analysis.Pass, models map[string]common.Model) { + // TODO: unexpected duplicated relations + var knownModels []string + for _, model := range models { + for _, field := range model.Fields { + m2mRelation := field.GetParam("many2many") + if m2mRelation != nil { + relatedModel := common.GetModelFromType(field.Type, models) + if relatedModel == nil { + pass.Reportf(field.Pos, "Failed to resolve related model type") + return + } + + backReference := common.FindBackReferenceInM2M(m2mRelation.Value, *relatedModel) + if backReference != nil { + if slices.Contains(knownModels, relatedModel.Name) { + continue + } else { + knownModels = append(knownModels, model.Name) + knownModels = append(knownModels, relatedModel.Name) + } + CheckTypesOfM2M(pass, model.Name, relatedModel.Name, m2mRelation.Value, field, *backReference) + // TODO: check foreign key and references + fmt.Printf("Found M2M relation between \"%s\" and \"%s\"\n", model.Name, relatedModel.Name) + } else { + // Here you can forbid M2M relations without back-reference + // TODO: process m2m without backref + } + } else { + // TODO: check [] and process m:1 + } + } + } +} + +func run(pass *analysis.Pass) (any, error) { + models := make(map[string]common.Model) + common.ParseModels(pass, &models) + CheckMany2Many(pass, models) + + return nil, nil +} diff --git a/tests/relations_check_test.go b/tests/relations_check_test.go new file mode 100644 index 0000000..3cd8d1b --- /dev/null +++ b/tests/relations_check_test.go @@ -0,0 +1,12 @@ +package tests + +import ( + "golang.org/x/tools/go/analysis/analysistest" + "gormlint/relationsCheck" + "testing" +) + +func TestRelationsCheck(t *testing.T) { + t.Parallel() + analysistest.Run(t, analysistest.TestData(), relationsCheck.RelationsAnalyzer, "relations_check") +} diff --git a/tests/testdata/src/references_check/negative.go b/tests/testdata/src/references_check/negative.go index f3bd40d..d7d338b 100644 --- a/tests/testdata/src/references_check/negative.go +++ b/tests/testdata/src/references_check/negative.go @@ -1,5 +1,7 @@ package references_check +// TODO: add test with annotations on back-references + type WorkArea struct { Id uint `gorm:"primaryKey"` Workshop Workshop `gorm:"foreignKey:WorkshopId;references:Id;"` diff --git a/tests/testdata/src/references_check/positive.go b/tests/testdata/src/references_check/positive.go index 7efde8c..1a59dcb 100644 --- a/tests/testdata/src/references_check/positive.go +++ b/tests/testdata/src/references_check/positive.go @@ -1,5 +1,7 @@ package references_check +// TODO: add test with annotations on back-references + type User struct { Name string CompanyID string diff --git a/tests/testdata/src/relations_check/negative.go b/tests/testdata/src/relations_check/negative.go new file mode 100644 index 0000000..08665d6 --- /dev/null +++ b/tests/testdata/src/relations_check/negative.go @@ -0,0 +1,16 @@ +package relations_check + +type Library struct { + Id uint `gorm:"primaryKey"` + Books []*Book `gorm:"many2many:library_book;"` +} + +type Book struct { + Id uint `gorm:"primaryKey"` + Libraries []*Library `gorm:"many2many:library_book;"` +} + +type Employee struct { + Id uint `gorm:"primaryKey"` + Subordinates []*Employee `gorm:"many2many:employee_subordinates;"` // self-reference +} diff --git a/tests/testdata/src/relations_check/positive.go b/tests/testdata/src/relations_check/positive.go new file mode 100644 index 0000000..1de481a --- /dev/null +++ b/tests/testdata/src/relations_check/positive.go @@ -0,0 +1,21 @@ +package relations_check + +type Student struct { + Id uint `gorm:"primaryKey"` + Courses []Course `gorm:"many2many:student_courses;"` +} + +type Course struct { + Id uint `gorm:"primaryKey"` + Students []Course `gorm:"many2many:student_courses;"` // want "Invalid type `Course` in M2M relation \\(use \\[\\]\\*Student or self-reference\\)" +} + +type Author struct { + Id uint `gorm:"primaryKey"` + Articles []Article `gorm:"many2many:author_articles;"` +} + +type Article struct { + Id uint `gorm:"primaryKey"` + Authors Author `gorm:"many2many:author_articles;"` // want "M2M relation `author_articles` with bad type `Author` \\(should be a slice\\)" +}