package model import ( "testing" ) func newTestRepo() *Repository { return &Repository{ models: []Model{ { ID: "model-1", Name: "YOLOv8", Description: "Object detection model", TaskType: "object_detection", SupportedHardware: []string{"KL720", "KL730"}, }, { ID: "model-2", Name: "ResNet", Description: "Classification model", TaskType: "classification", SupportedHardware: []string{"KL720"}, }, { ID: "custom-1", Name: "My Custom Model", TaskType: "object_detection", IsCustom: true, }, }, } } func TestRepository_List(t *testing.T) { repo := newTestRepo() tests := []struct { name string filter ModelFilter expectedCount int }{ {"no filter", ModelFilter{}, 3}, {"filter by task type", ModelFilter{TaskType: "object_detection"}, 2}, {"filter by hardware", ModelFilter{Hardware: "KL730"}, 1}, {"filter by query", ModelFilter{Query: "YOLO"}, 1}, {"query case insensitive", ModelFilter{Query: "resnet"}, 1}, {"no matches", ModelFilter{TaskType: "segmentation"}, 0}, {"combined filters", ModelFilter{TaskType: "object_detection", Query: "YOLO"}, 1}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { results, count := repo.List(tt.filter) if count != tt.expectedCount { t.Errorf("List() count = %d, want %d", count, tt.expectedCount) } if len(results) != tt.expectedCount { t.Errorf("List() len(results) = %d, want %d", len(results), tt.expectedCount) } }) } } func TestRepository_GetByID(t *testing.T) { repo := newTestRepo() tests := []struct { name string id string wantErr bool }{ {"existing model", "model-1", false}, {"another existing", "model-2", false}, {"non-existing", "model-999", true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { m, err := repo.GetByID(tt.id) if (err != nil) != tt.wantErr { t.Errorf("GetByID() error = %v, wantErr %v", err, tt.wantErr) } if !tt.wantErr && m.ID != tt.id { t.Errorf("GetByID() ID = %s, want %s", m.ID, tt.id) } }) } } func TestRepository_Add(t *testing.T) { repo := &Repository{models: []Model{}} m := Model{ID: "new-1", Name: "New Model"} repo.Add(m) if repo.Count() != 1 { t.Errorf("Count() = %d, want 1", repo.Count()) } } func TestRepository_Remove(t *testing.T) { repo := newTestRepo() tests := []struct { name string id string wantErr bool }{ {"remove custom model", "custom-1", false}, {"cannot remove built-in", "model-1", true}, {"not found", "model-999", true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { err := repo.Remove(tt.id) if (err != nil) != tt.wantErr { t.Errorf("Remove() error = %v, wantErr %v", err, tt.wantErr) } }) } }