From 864953076a318d938ea44ccf31cbe13801a800fe Mon Sep 17 00:00:00 2001 From: Levente Batuska Date: Wed, 1 Jan 2025 22:24:59 +0100 Subject: [PATCH] fix(Type): Scan method fixes --- type/interfaces.go | 4 + type/optional.go | 205 ++++++++++++++++++++++++------------------ type/optional_test.go | 14 +++ 3 files changed, 136 insertions(+), 87 deletions(-) diff --git a/type/interfaces.go b/type/interfaces.go index f33962e..4192c8a 100644 --- a/type/interfaces.go +++ b/type/interfaces.go @@ -1,5 +1,7 @@ package Type +import "database/sql" + // Created to abstract over Is_some and Is_ok type ValueContainer interface { HasValue() bool @@ -43,4 +45,6 @@ var ( _ Resulter[any] = (*Result[any])(nil) _ ValueContainer = (*Optional[any])(nil) _ ValueContainer = (*Result[any])(nil) + _ sql.Scanner = (*Optional[any])(nil) + // _ sql.Scanner = (*Result[any])(nil) ) diff --git a/type/optional.go b/type/optional.go index 8c83a6e..398ed4c 100644 --- a/type/optional.go +++ b/type/optional.go @@ -3,6 +3,7 @@ package Type import ( "database/sql" "fmt" + "reflect" "time" Assert "github.com/lbatuska/goutils/assert" @@ -114,111 +115,141 @@ func (opt *Optional[T]) OkOrElse(f func() error) Result[T] { } func (opt *Optional[T]) Scan(src interface{}) error { + Assert.NotNil(opt) + opt.present = false // DB had a null value if src == nil { - opt.present = false return nil } + // If T is a scanner - if scanner, ok := any(&opt.value).(sql.Scanner); ok { + if scanner, ok := any(opt.value).(sql.Scanner); ok { if err := scanner.Scan(src); err != nil { - opt.present = false return err } opt.present = true return nil } - // We implement parsing for some builtin types + + if scanres := opt.scanBuiltin(src); scanres.IsSome() { + return scanres.Unwrap() + } + + return fmt.Errorf("Unsupported type %T or differs from Optional[%T], and the type doesn't implement sql.Scanner!", src, opt.value) +} + +func (opt *Optional[T]) scanBuiltin(src interface{}) Optional[error] { opt.present = false - switch v := any(&opt.value).(type) { + // First handle the special cases where we allow conversion between types + // This is usually just parsing []byte into type + if scanres := opt.scanTimeSpecial(src); scanres.IsSome() { + return scanres + } + if scanres := opt.scanStringSpecial(src); scanres.IsSome() { + return scanres + } - case *string: - if str, ok := src.(string); ok { - *v = str - opt.present = true - return nil - } - if b, ok := src.([]byte); ok { - *v = string(b) - opt.present = true - return nil - } + srcVal := reflect.ValueOf(src) + optType := reflect.TypeOf(opt.value) - case *int: - if s, ok := src.(int); ok { - *v = s - opt.present = true - return nil - } - case *int32: - if s, ok := src.(int32); ok { - *v = s - opt.present = true - return nil - } - case *int64: - if s, ok := src.(int64); ok { - *v = s - opt.present = true - return nil - } + optElemType := optType + if optElemType.Kind() == reflect.Pointer { + optElemType = optElemType.Elem() + } - case *bool: - switch s := src.(type) { - case bool: - *v = s - opt.present = true - return nil - // We could technically allow this, however we try to avoid implicit conversions to ensure type safety. - // case string: - // if s == "1" || s == "true" || s == "t" { - // *v = true - // opt.present = true - // return nil - // } - // if s == "0" || s == "false" || s == "f" { - // *v = false - // opt.present = true - // return nil - // } - } + srcElemType := srcVal.Type() + if srcElemType.Kind() == reflect.Pointer { + srcElemType = srcElemType.Elem() + } - case *float64: - switch s := src.(type) { - case float64: - *v = s - opt.present = true - return nil - case float32: - *v = float64(s) - opt.present = true - return nil - } + if srcElemType != optElemType { + return Some(fmt.Errorf("Optional[%T] (aka %T) differs from %T!", opt.value, opt.value, src)) + } - case *float32: - switch s := src.(type) { - case float32: - *v = s - opt.present = true - return nil + if optType.Kind() == reflect.Pointer { + if srcVal.Kind() == reflect.Pointer { + opt.value = srcVal.Interface().(T) + } else { + newPtr := reflect.New(optElemType) + newPtr.Elem().Set(srcVal) + opt.value = newPtr.Interface().(T) } - - case *time.Time: - if t, ok := src.(time.Time); ok { - *v = t - opt.present = true - return nil - } - if b, ok := src.([]byte); ok { - parsedTime, err := time.Parse(time.RFC3339, string(b)) - if err == nil { - *v = parsedTime - opt.present = true - return nil - } + } else { + if srcVal.Kind() == reflect.Ptr { + opt.value = srcVal.Elem().Interface().(T) + } else { + opt.value = srcVal.Interface().(T) } } - // We couldnt parse the value - opt.present = false - return fmt.Errorf("unsupported type %T or differs from Optional[%T], and the type doesn't implement sql.Scanner", src, opt.value) + opt.present = true + return Some[error](nil) +} + +func (opt *Optional[T]) scanStringSpecial(src interface{}) Optional[error] { + opt.present = false + switch v := any(opt.value).(type) { + case *string: + switch s := src.(type) { + case []byte: + *v = string(s) + goto ok + case *[]byte: + *v = string(*s) + goto ok + } + case string: + switch s := src.(type) { + case []byte: + reflect.ValueOf(&opt.value).Elem().Set(reflect.ValueOf(string(s))) + goto ok + case *[]byte: + reflect.ValueOf(&opt.value).Elem().Set(reflect.ValueOf(string(*s))) + goto ok + } + } + return None[error]() +ok: + opt.present = true + return Some[error](nil) +} + +func (opt *Optional[T]) scanTimeSpecial(src interface{}) Optional[error] { + opt.present = false + switch v := any(opt.value).(type) { + case *time.Time: + switch t := src.(type) { + case []byte: + parsedTime, err := time.Parse(time.RFC3339, string(t)) + if err == nil { + *v = parsedTime + goto ok + } + case *[]byte: + parsedTime, err := time.Parse(time.RFC3339, string(*t)) + if err == nil { + *v = parsedTime + goto ok + } + } + case time.Time: + switch t := src.(type) { + case []byte: + parsedTime, err := time.Parse(time.RFC3339, string(t)) + if err == nil { + reflect.ValueOf(&opt.value).Elem().Set(reflect.ValueOf(parsedTime)) + goto ok + } + case *[]byte: + parsedTime, err := time.Parse(time.RFC3339, string(*t)) + if err == nil { + reflect.ValueOf(&opt.value).Elem().Set(reflect.ValueOf(parsedTime)) + goto ok + } + } + + } + return None[error]() +ok: + opt.present = true + return Some[error](nil) } diff --git a/type/optional_test.go b/type/optional_test.go index a2540f4..d32b3df 100644 --- a/type/optional_test.go +++ b/type/optional_test.go @@ -18,6 +18,20 @@ var ( nilOptional = (*Optional[int])(nil) ) +func Test_optionalScan(t *testing.T) { + a := None[string]() + b := None[*string]() + c := None[string]() + d := None[int]() + a.Scan("a") + b.Scan("b") + c.Scan(Ptr("c")) + Testing.AssertEqual(t, "a", a.Unwrap()) + Testing.AssertEqual(t, "b", *b.Unwrap()) + Testing.AssertEqual(t, "c", c.Unwrap()) + Testing.AssertPanic(t, func() { d.Unwrap() }) +} + func Test_optionalSome(t *testing.T) { Testing.AssertEqual(t, x, y) Testing.AssertEqual(t, v, u)