feat(Type): Scan method fixes for Result

This commit is contained in:
2025-01-04 01:55:26 +01:00
parent 1fb6970d5a
commit 80fa0d5279
3 changed files with 144 additions and 91 deletions

View File

@@ -49,8 +49,8 @@ var (
_ ValueContainer = (*Optional[any])(nil) _ ValueContainer = (*Optional[any])(nil)
_ ValueContainer = (*Result[any])(nil) _ ValueContainer = (*Result[any])(nil)
_ sql.Scanner = (*Optional[any])(nil) _ sql.Scanner = (*Optional[any])(nil)
// _ sql.Scanner = (*Result[any])(nil) _ sql.Scanner = (*Result[any])(nil)
_ json.Marshaler = (*Optional[any])(nil) _ json.Marshaler = (*Optional[any])(nil)
// _ json.Marshaler = (*Result[any])(nil) // _ json.Marshaler = (*Result[any])(nil)
_ json.Unmarshaler = (*Optional[any])(nil) _ json.Unmarshaler = (*Optional[any])(nil)
// _ json.Unmarshaler = (*Result[any])(nil) // _ json.Unmarshaler = (*Result[any])(nil)

View File

@@ -4,6 +4,7 @@ import (
"database/sql" "database/sql"
"errors" "errors"
"fmt" "fmt"
"reflect"
"time" "time"
Assert "github.com/lbatuska/goutils/assert" Assert "github.com/lbatuska/goutils/assert"
@@ -139,118 +140,155 @@ func (res *Result[T]) Err() Optional[error] {
} }
func (res *Result[T]) Scan(src interface{}) error { func (res *Result[T]) Scan(src interface{}) error {
Assert.NotNil(res)
e := fmt.Errorf("Unsupported type %T or differs from Result[%T], and the type doesn't implement sql.Scanner!",
src, res.value)
res.err = e
// DB had a null value // DB had a null value
if src == nil { if src == nil {
res.err = errors.New("src was nil!")
return nil return nil
} }
// If T is a scanner // If T is a scanner
if scanner, ok := any(&res.value).(sql.Scanner); ok { if scanner, ok := any(res.value).(sql.Scanner); ok {
if err := scanner.Scan(src); err != nil { if err := scanner.Scan(src); err != nil {
res.err = err
return err return err
} }
res.err = nil res.err = nil
return nil return nil
} }
// We implement parsing for some builtin types
mismatchErr := fmt.Errorf("Type of src (%T) doesn't match type of Result[%T]!", src, res.value)
switch v := any(&res.value).(type) { if scanres := res.scanBuiltin(src); scanres.IsSome() {
return scanres.Unwrap()
}
return e
}
func (res *Result[T]) scanBuiltin(src interface{}) Optional[error] {
res.err = nil
// First handle the special cases where we allow conversion between types
// This is usually just parsing []byte into type
if scanres := res.scanTimeSpecial(src); scanres.IsSome() {
return scanres
}
if scanres := res.scanStringSpecial(src); scanres.IsSome() {
return scanres
}
srcVal := reflect.ValueOf(src)
optType := reflect.TypeOf(res.value)
optElemType := optType
if optElemType.Kind() == reflect.Pointer {
optElemType = optElemType.Elem()
}
srcElemType := srcVal.Type()
if srcElemType.Kind() == reflect.Pointer {
srcElemType = srcElemType.Elem()
}
if srcElemType != optElemType {
e := fmt.Errorf("Result[%T] (aka %T) differs from %T!", res.value, res.value, src)
res.err = e
return Some(e)
}
if optType.Kind() == reflect.Pointer {
if srcVal.Kind() == reflect.Pointer {
res.value = srcVal.Interface().(T)
} else {
newPtr := reflect.New(optElemType)
newPtr.Elem().Set(srcVal)
res.value = newPtr.Interface().(T)
}
} else {
if srcVal.Kind() == reflect.Pointer {
res.value = srcVal.Elem().Interface().(T)
} else {
res.value = srcVal.Interface().(T)
}
}
res.err = nil
return Some[error](nil)
}
func (res *Result[T]) scanStringSpecial(src interface{}) Optional[error] {
switch v := any(res.value).(type) {
case *string: case *string:
if str, ok := src.(string); ok {
*v = str
res.err = nil
return nil
}
if b, ok := src.([]byte); ok {
*v = string(b)
res.err = nil
return nil
}
res.err = mismatchErr
return res.err
case *int:
if s, ok := src.(int); ok {
*v = s
res.err = nil
return nil
}
res.err = mismatchErr
return res.err
case *int32:
if s, ok := src.(int32); ok {
*v = s
res.err = nil
return nil
}
res.err = mismatchErr
return res.err
case *int64:
if s, ok := src.(int64); ok {
*v = s
res.err = nil
return nil
}
res.err = mismatchErr
return res.err
case *bool:
switch s := src.(type) { switch s := src.(type) {
case bool: case []byte:
*v = s *v = string(s)
res.err = nil goto ok
return nil case *[]byte:
*v = string(*s)
goto ok
} }
res.err = mismatchErr case string:
return res.err
case *float64:
switch s := src.(type) { switch s := src.(type) {
case float64: case []byte:
*v = s reflect.ValueOf(&res.value).Elem().Set(reflect.ValueOf(string(s)))
res.err = nil goto ok
return nil case *[]byte:
case float32: reflect.ValueOf(&res.value).Elem().Set(reflect.ValueOf(string(*s)))
*v = float64(s) goto ok
res.err = nil
return nil
} }
res.err = mismatchErr }
return res.err return None[error]()
ok:
case *float32: res.err = nil
switch s := src.(type) { return Some[error](nil)
case float32: }
*v = s
res.err = nil
return nil
}
res.err = mismatchErr
return res.err
func (res *Result[T]) scanTimeSpecial(src interface{}) Optional[error] {
switch v := any(res.value).(type) {
case *time.Time: case *time.Time:
if t, ok := src.(time.Time); ok { switch t := src.(type) {
*v = t case []byte:
res.err = nil parsedTime, err := time.Parse(time.RFC3339, string(t))
return nil
}
if b, ok := src.([]byte); ok {
parsedTime, err := time.Parse(time.RFC3339, string(b)) // or use other formats as necessary
if err == nil { if err == nil {
*v = parsedTime *v = parsedTime
res.err = nil goto ok
return nil } else {
res.err = err
return Some(err)
}
case *[]byte:
parsedTime, err := time.Parse(time.RFC3339, string(*t))
if err == nil {
*v = parsedTime
goto ok
} else {
res.err = err
return Some(err)
} }
} }
res.err = mismatchErr case time.Time:
return res.err switch t := src.(type) {
case []byte:
parsedTime, err := time.Parse(time.RFC3339, string(t))
if err == nil {
reflect.ValueOf(&res.value).Elem().Set(reflect.ValueOf(parsedTime))
goto ok
} else {
res.err = err
return Some(err)
}
case *[]byte:
parsedTime, err := time.Parse(time.RFC3339, string(*t))
if err == nil {
reflect.ValueOf(&res.value).Elem().Set(reflect.ValueOf(parsedTime))
goto ok
} else {
res.err = err
return Some(err)
}
}
} }
// We couldnt parse the value return None[error]()
err := fmt.Errorf("Unsupported type %T, and the type doesn't implement sql.Scanner!", src) ok:
res.err = err res.err = nil
return err return Some[error](nil)
} }

View File

@@ -15,6 +15,21 @@ var (
nilResult = (*Result[int])(nil) nilResult = (*Result[int])(nil)
) )
func Test_resultScan(t *testing.T) {
a := Ok("Not A!")
b := Ok(0)
c := Ok(0)
d := Err[*string](errors.New(""))
a.Scan("a")
b.Scan(1)
c.Scan("c")
d.Scan(Ptr("d"))
Testing.AssertEqual(t, a.Unwrap(), "a")
Testing.AssertEqual(t, b.Unwrap(), 1)
Testing.AssertError(t, c.UnwrapErr())
Testing.AssertEqual(t, *d.Unwrap(), "d")
}
func Test_resultConstructors(t *testing.T) { func Test_resultConstructors(t *testing.T) {
err := errors.New("some error") err := errors.New("some error")
w := Err[int](err) w := Err[int](err)