feat: many2many relations check
This commit is contained in:
@@ -5,6 +5,7 @@ import (
|
||||
"gormlint/foreignKeyCheck"
|
||||
"gormlint/nullSafetyCheck"
|
||||
"gormlint/referencesCheck"
|
||||
"gormlint/relationsCheck"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -12,5 +13,6 @@ func main() {
|
||||
nullSafetyCheck.NullSafetyAnalyzer,
|
||||
referencesCheck.ReferenceAnalyzer,
|
||||
foreignKeyCheck.ForeignKeyCheck,
|
||||
relationsCheck.RelationsAnalyzer,
|
||||
)
|
||||
}
|
||||
|
||||
100
common/finders.go
Normal file
100
common/finders.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package common
|
||||
|
||||
import (
|
||||
"go/ast"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func GetModelField(model *Model, fieldName string) *Field {
|
||||
field, fieldExists := model.Fields[fieldName]
|
||||
if fieldExists {
|
||||
return &field
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func GetModelFromType(modelType ast.Expr, models map[string]Model) *Model {
|
||||
baseType := ResolveBaseType(modelType)
|
||||
if baseType != nil {
|
||||
return GetRelatedModel(*baseType, models)
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func GetRelatedModel(modelName string, models map[string]Model) *Model {
|
||||
model, modelExists := models[modelName]
|
||||
if modelExists {
|
||||
return &model
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func FindParamValue(paramName string, params []string) *string {
|
||||
for _, rawParam := range params {
|
||||
pair := strings.Split(rawParam, ":")
|
||||
if len(pair) < 2 {
|
||||
return nil
|
||||
}
|
||||
if strings.ToLower(pair[0]) == strings.ToLower(paramName) {
|
||||
return &pair[1]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func FindModelParam(paramName string, model Model) *Param {
|
||||
for _, field := range model.Fields {
|
||||
for _, param := range field.Params {
|
||||
pair := strings.Split(param, ":")
|
||||
if len(pair) < 2 {
|
||||
return nil
|
||||
}
|
||||
if strings.ToLower(pair[0]) == strings.ToLower(paramName) {
|
||||
return &Param{
|
||||
Name: pair[0],
|
||||
Value: pair[1],
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func FindReferencesInM2M(m2mReference Field, relatedModel Model) *Field {
|
||||
/* Find `references` field in m2m relation */
|
||||
referencesTagValue := FindParamValue("references", m2mReference.Params)
|
||||
if referencesTagValue != nil {
|
||||
return GetModelField(&relatedModel, *referencesTagValue)
|
||||
} else {
|
||||
for _, field := range relatedModel.Fields {
|
||||
for _, opt := range field.Options {
|
||||
if opt == "primaryKey" {
|
||||
return &field
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, field := range relatedModel.Fields {
|
||||
if strings.ToLower(field.Name) == "id" {
|
||||
return &field
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func FindBackReferenceInM2M(relationName string, relatedModel Model) *Field {
|
||||
for _, field := range relatedModel.Fields {
|
||||
m2mRelation := field.GetParam("many2many")
|
||||
if m2mRelation != nil {
|
||||
if m2mRelation.Value == relationName {
|
||||
return &field
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
//func findForeignKey()
|
||||
11
common/isSlice.go
Normal file
11
common/isSlice.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package common
|
||||
|
||||
import "go/ast"
|
||||
|
||||
func IsSlice(expr ast.Expr) bool {
|
||||
arrayType, ok := expr.(*ast.ArrayType)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return arrayType.Len == nil
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package common
|
||||
import (
|
||||
"go/ast"
|
||||
"go/token"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Field struct {
|
||||
@@ -21,3 +22,50 @@ type Model struct {
|
||||
Pos token.Pos
|
||||
Comment string
|
||||
}
|
||||
|
||||
type Param struct {
|
||||
Name string
|
||||
Value string
|
||||
}
|
||||
|
||||
func (model *Model) GetParam(name string) *Param {
|
||||
for _, field := range model.Fields {
|
||||
for _, param := range field.Params {
|
||||
pair := strings.SplitN(param, ":", 2)
|
||||
if len(pair) != 2 {
|
||||
return nil
|
||||
}
|
||||
if strings.ToLower(pair[0]) == strings.ToLower(name) {
|
||||
return &Param{
|
||||
Name: pair[0],
|
||||
Value: pair[1],
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (model *Model) HasParam(name string) bool {
|
||||
return model.GetParam(name) != nil
|
||||
}
|
||||
|
||||
func (field *Field) HasParam(name string) bool {
|
||||
return field.GetParam(name) != nil
|
||||
}
|
||||
|
||||
func (field *Field) GetParam(name string) *Param {
|
||||
for _, param := range field.Params {
|
||||
pair := strings.SplitN(param, ":", 2)
|
||||
if len(pair) != 2 {
|
||||
return nil
|
||||
}
|
||||
if strings.ToLower(pair[0]) == strings.ToLower(name) {
|
||||
return &Param{
|
||||
Name: pair[0],
|
||||
Value: pair[1],
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
89
relationsCheck/relationsCheck.go
Normal file
89
relationsCheck/relationsCheck.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package relationsCheck
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"golang.org/x/tools/go/analysis"
|
||||
"gormlint/common"
|
||||
"slices"
|
||||
)
|
||||
|
||||
// RelationsAnalyzer todo: add URL for rule
|
||||
var RelationsAnalyzer = &analysis.Analyzer{
|
||||
Name: "GormReferencesCheck",
|
||||
Doc: "report about invalid references in models",
|
||||
Run: run,
|
||||
}
|
||||
|
||||
func CheckTypesOfM2M(pass *analysis.Pass, modelName string, relatedModelName string, relationName string, reference common.Field, backReference common.Field) {
|
||||
if !common.IsSlice(reference.Type) {
|
||||
pass.Reportf(reference.Pos, "M2M relation `%s` with bad type `%s` (should be a slice)", relationName, reference.Type)
|
||||
return
|
||||
}
|
||||
if !common.IsSlice(backReference.Type) {
|
||||
pass.Reportf(backReference.Pos, "M2M relation `%s` with bad type `%s` (should be a slice)", relationName, backReference.Type)
|
||||
return
|
||||
}
|
||||
|
||||
referenceBaseType := common.ResolveBaseType(reference.Type)
|
||||
if referenceBaseType == nil {
|
||||
pass.Reportf(reference.Pos, "Failed to resolve field type: `%s`", reference.Type)
|
||||
return
|
||||
}
|
||||
backReferenceBaseType := common.ResolveBaseType(backReference.Type)
|
||||
if backReferenceBaseType == nil {
|
||||
pass.Reportf(reference.Pos, "Failed to resolve type: `%s`", reference.Type)
|
||||
return
|
||||
}
|
||||
|
||||
if *backReferenceBaseType != modelName {
|
||||
pass.Reportf(backReference.Pos, "Invalid type `%s` in M2M relation (use []*%s or self-reference)", *backReferenceBaseType, modelName)
|
||||
return
|
||||
}
|
||||
|
||||
if *referenceBaseType != relatedModelName {
|
||||
pass.Reportf(reference.Pos, "Invalid type `%s` in M2M relation (use []*%s or self-reference)", *referenceBaseType, relatedModelName)
|
||||
}
|
||||
}
|
||||
|
||||
func CheckMany2Many(pass *analysis.Pass, models map[string]common.Model) {
|
||||
// TODO: unexpected duplicated relations
|
||||
var knownModels []string
|
||||
for _, model := range models {
|
||||
for _, field := range model.Fields {
|
||||
m2mRelation := field.GetParam("many2many")
|
||||
if m2mRelation != nil {
|
||||
relatedModel := common.GetModelFromType(field.Type, models)
|
||||
if relatedModel == nil {
|
||||
pass.Reportf(field.Pos, "Failed to resolve related model type")
|
||||
return
|
||||
}
|
||||
|
||||
backReference := common.FindBackReferenceInM2M(m2mRelation.Value, *relatedModel)
|
||||
if backReference != nil {
|
||||
if slices.Contains(knownModels, relatedModel.Name) {
|
||||
continue
|
||||
} else {
|
||||
knownModels = append(knownModels, model.Name)
|
||||
knownModels = append(knownModels, relatedModel.Name)
|
||||
}
|
||||
CheckTypesOfM2M(pass, model.Name, relatedModel.Name, m2mRelation.Value, field, *backReference)
|
||||
// TODO: check foreign key and references
|
||||
fmt.Printf("Found M2M relation between \"%s\" and \"%s\"\n", model.Name, relatedModel.Name)
|
||||
} else {
|
||||
// Here you can forbid M2M relations without back-reference
|
||||
// TODO: process m2m without backref
|
||||
}
|
||||
} else {
|
||||
// TODO: check [] and process m:1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func run(pass *analysis.Pass) (any, error) {
|
||||
models := make(map[string]common.Model)
|
||||
common.ParseModels(pass, &models)
|
||||
CheckMany2Many(pass, models)
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
12
tests/relations_check_test.go
Normal file
12
tests/relations_check_test.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"golang.org/x/tools/go/analysis/analysistest"
|
||||
"gormlint/relationsCheck"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRelationsCheck(t *testing.T) {
|
||||
t.Parallel()
|
||||
analysistest.Run(t, analysistest.TestData(), relationsCheck.RelationsAnalyzer, "relations_check")
|
||||
}
|
||||
@@ -1,5 +1,7 @@
|
||||
package references_check
|
||||
|
||||
// TODO: add test with annotations on back-references
|
||||
|
||||
type WorkArea struct {
|
||||
Id uint `gorm:"primaryKey"`
|
||||
Workshop Workshop `gorm:"foreignKey:WorkshopId;references:Id;"`
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
package references_check
|
||||
|
||||
// TODO: add test with annotations on back-references
|
||||
|
||||
type User struct {
|
||||
Name string
|
||||
CompanyID string
|
||||
|
||||
16
tests/testdata/src/relations_check/negative.go
vendored
Normal file
16
tests/testdata/src/relations_check/negative.go
vendored
Normal file
@@ -0,0 +1,16 @@
|
||||
package relations_check
|
||||
|
||||
type Library struct {
|
||||
Id uint `gorm:"primaryKey"`
|
||||
Books []*Book `gorm:"many2many:library_book;"`
|
||||
}
|
||||
|
||||
type Book struct {
|
||||
Id uint `gorm:"primaryKey"`
|
||||
Libraries []*Library `gorm:"many2many:library_book;"`
|
||||
}
|
||||
|
||||
type Employee struct {
|
||||
Id uint `gorm:"primaryKey"`
|
||||
Subordinates []*Employee `gorm:"many2many:employee_subordinates;"` // self-reference
|
||||
}
|
||||
21
tests/testdata/src/relations_check/positive.go
vendored
Normal file
21
tests/testdata/src/relations_check/positive.go
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
package relations_check
|
||||
|
||||
type Student struct {
|
||||
Id uint `gorm:"primaryKey"`
|
||||
Courses []Course `gorm:"many2many:student_courses;"`
|
||||
}
|
||||
|
||||
type Course struct {
|
||||
Id uint `gorm:"primaryKey"`
|
||||
Students []Course `gorm:"many2many:student_courses;"` // want "Invalid type `Course` in M2M relation \\(use \\[\\]\\*Student or self-reference\\)"
|
||||
}
|
||||
|
||||
type Author struct {
|
||||
Id uint `gorm:"primaryKey"`
|
||||
Articles []Article `gorm:"many2many:author_articles;"`
|
||||
}
|
||||
|
||||
type Article struct {
|
||||
Id uint `gorm:"primaryKey"`
|
||||
Authors Author `gorm:"many2many:author_articles;"` // want "M2M relation `author_articles` with bad type `Author` \\(should be a slice\\)"
|
||||
}
|
||||
Reference in New Issue
Block a user