From 80fa0d5279774a4eba8f381ab2a03ccf2b0dbb22 Mon Sep 17 00:00:00 2001 From: Levente Batuska Date: Sat, 4 Jan 2025 01:55:26 +0100 Subject: [PATCH] feat(Type): Scan method fixes for Result --- type/interfaces.go | 4 +- type/result.go | 216 ++++++++++++++++++++++++++------------------ type/result_test.go | 15 +++ 3 files changed, 144 insertions(+), 91 deletions(-) diff --git a/type/interfaces.go b/type/interfaces.go index e0f5683..cc71cc8 100644 --- a/type/interfaces.go +++ b/type/interfaces.go @@ -49,8 +49,8 @@ var ( _ ValueContainer = (*Optional[any])(nil) _ ValueContainer = (*Result[any])(nil) _ sql.Scanner = (*Optional[any])(nil) - // _ sql.Scanner = (*Result[any])(nil) - _ json.Marshaler = (*Optional[any])(nil) + _ sql.Scanner = (*Result[any])(nil) + _ json.Marshaler = (*Optional[any])(nil) // _ json.Marshaler = (*Result[any])(nil) _ json.Unmarshaler = (*Optional[any])(nil) // _ json.Unmarshaler = (*Result[any])(nil) diff --git a/type/result.go b/type/result.go index 6579c0f..62a1f2e 100644 --- a/type/result.go +++ b/type/result.go @@ -4,6 +4,7 @@ import ( "database/sql" "errors" "fmt" + "reflect" "time" Assert "github.com/lbatuska/goutils/assert" @@ -139,118 +140,155 @@ func (res *Result[T]) Err() Optional[error] { } func (res *Result[T]) Scan(src interface{}) error { + Assert.NotNil(res) + e := fmt.Errorf("Unsupported type %T or differs from Result[%T], and the type doesn't implement sql.Scanner!", + src, res.value) + res.err = e // DB had a null value if src == nil { - res.err = errors.New("src was nil!") return nil } + // If T is a scanner - if scanner, ok := any(&res.value).(sql.Scanner); ok { + if scanner, ok := any(res.value).(sql.Scanner); ok { if err := scanner.Scan(src); err != nil { - res.err = err return err } res.err = nil return nil } - // We implement parsing for some builtin types - mismatchErr := fmt.Errorf("Type of src (%T) doesn't match type of Result[%T]!", src, res.value) - switch v := any(&res.value).(type) { + if scanres := res.scanBuiltin(src); scanres.IsSome() { + return scanres.Unwrap() + } + return e +} + +func (res *Result[T]) scanBuiltin(src interface{}) Optional[error] { + res.err = nil + // First handle the special cases where we allow conversion between types + // This is usually just parsing []byte into type + if scanres := res.scanTimeSpecial(src); scanres.IsSome() { + return scanres + } + if scanres := res.scanStringSpecial(src); scanres.IsSome() { + return scanres + } + + srcVal := reflect.ValueOf(src) + optType := reflect.TypeOf(res.value) + + optElemType := optType + if optElemType.Kind() == reflect.Pointer { + optElemType = optElemType.Elem() + } + + srcElemType := srcVal.Type() + if srcElemType.Kind() == reflect.Pointer { + srcElemType = srcElemType.Elem() + } + + if srcElemType != optElemType { + e := fmt.Errorf("Result[%T] (aka %T) differs from %T!", res.value, res.value, src) + res.err = e + return Some(e) + } + + if optType.Kind() == reflect.Pointer { + if srcVal.Kind() == reflect.Pointer { + res.value = srcVal.Interface().(T) + } else { + newPtr := reflect.New(optElemType) + newPtr.Elem().Set(srcVal) + res.value = newPtr.Interface().(T) + } + } else { + if srcVal.Kind() == reflect.Pointer { + res.value = srcVal.Elem().Interface().(T) + } else { + res.value = srcVal.Interface().(T) + } + } + res.err = nil + return Some[error](nil) +} + +func (res *Result[T]) scanStringSpecial(src interface{}) Optional[error] { + switch v := any(res.value).(type) { case *string: - if str, ok := src.(string); ok { - *v = str - res.err = nil - return nil - } - if b, ok := src.([]byte); ok { - *v = string(b) - res.err = nil - return nil - } - res.err = mismatchErr - return res.err - - case *int: - if s, ok := src.(int); ok { - *v = s - res.err = nil - return nil - } - res.err = mismatchErr - return res.err - case *int32: - if s, ok := src.(int32); ok { - *v = s - res.err = nil - return nil - } - res.err = mismatchErr - return res.err - case *int64: - if s, ok := src.(int64); ok { - *v = s - res.err = nil - return nil - } - res.err = mismatchErr - return res.err - - case *bool: switch s := src.(type) { - case bool: - *v = s - res.err = nil - return nil + case []byte: + *v = string(s) + goto ok + case *[]byte: + *v = string(*s) + goto ok } - res.err = mismatchErr - return res.err - - case *float64: + case string: switch s := src.(type) { - case float64: - *v = s - res.err = nil - return nil - case float32: - *v = float64(s) - res.err = nil - return nil - + case []byte: + reflect.ValueOf(&res.value).Elem().Set(reflect.ValueOf(string(s))) + goto ok + case *[]byte: + reflect.ValueOf(&res.value).Elem().Set(reflect.ValueOf(string(*s))) + goto ok } - res.err = mismatchErr - return res.err - - case *float32: - switch s := src.(type) { - case float32: - *v = s - res.err = nil - return nil - } - res.err = mismatchErr - return res.err + } + return None[error]() +ok: + res.err = nil + return Some[error](nil) +} +func (res *Result[T]) scanTimeSpecial(src interface{}) Optional[error] { + switch v := any(res.value).(type) { case *time.Time: - if t, ok := src.(time.Time); ok { - *v = t - res.err = nil - return nil - } - if b, ok := src.([]byte); ok { - parsedTime, err := time.Parse(time.RFC3339, string(b)) // or use other formats as necessary + switch t := src.(type) { + case []byte: + parsedTime, err := time.Parse(time.RFC3339, string(t)) if err == nil { *v = parsedTime - res.err = nil - return nil + goto ok + } else { + res.err = err + return Some(err) + } + case *[]byte: + parsedTime, err := time.Parse(time.RFC3339, string(*t)) + if err == nil { + *v = parsedTime + goto ok + } else { + res.err = err + return Some(err) } } - res.err = mismatchErr - return res.err + case time.Time: + switch t := src.(type) { + case []byte: + parsedTime, err := time.Parse(time.RFC3339, string(t)) + if err == nil { + reflect.ValueOf(&res.value).Elem().Set(reflect.ValueOf(parsedTime)) + goto ok + } else { + res.err = err + return Some(err) + } + case *[]byte: + parsedTime, err := time.Parse(time.RFC3339, string(*t)) + if err == nil { + reflect.ValueOf(&res.value).Elem().Set(reflect.ValueOf(parsedTime)) + goto ok + } else { + res.err = err + return Some(err) + } + } + } - // We couldnt parse the value - err := fmt.Errorf("Unsupported type %T, and the type doesn't implement sql.Scanner!", src) - res.err = err - return err + return None[error]() +ok: + res.err = nil + return Some[error](nil) } diff --git a/type/result_test.go b/type/result_test.go index 30433b7..3d06875 100644 --- a/type/result_test.go +++ b/type/result_test.go @@ -15,6 +15,21 @@ var ( nilResult = (*Result[int])(nil) ) +func Test_resultScan(t *testing.T) { + a := Ok("Not A!") + b := Ok(0) + c := Ok(0) + d := Err[*string](errors.New("")) + a.Scan("a") + b.Scan(1) + c.Scan("c") + d.Scan(Ptr("d")) + Testing.AssertEqual(t, a.Unwrap(), "a") + Testing.AssertEqual(t, b.Unwrap(), 1) + Testing.AssertError(t, c.UnwrapErr()) + Testing.AssertEqual(t, *d.Unwrap(), "d") +} + func Test_resultConstructors(t *testing.T) { err := errors.New("some error") w := Err[int](err)