fix(Type): Scan method fixes

This commit is contained in:
2025-01-01 22:24:59 +01:00
parent b5fc0df93b
commit 864953076a
3 changed files with 136 additions and 87 deletions

View File

@@ -1,5 +1,7 @@
package Type package Type
import "database/sql"
// Created to abstract over Is_some and Is_ok // Created to abstract over Is_some and Is_ok
type ValueContainer interface { type ValueContainer interface {
HasValue() bool HasValue() bool
@@ -43,4 +45,6 @@ var (
_ Resulter[any] = (*Result[any])(nil) _ Resulter[any] = (*Result[any])(nil)
_ ValueContainer = (*Optional[any])(nil) _ ValueContainer = (*Optional[any])(nil)
_ ValueContainer = (*Result[any])(nil) _ ValueContainer = (*Result[any])(nil)
_ sql.Scanner = (*Optional[any])(nil)
// _ sql.Scanner = (*Result[any])(nil)
) )

View File

@@ -3,6 +3,7 @@ package Type
import ( import (
"database/sql" "database/sql"
"fmt" "fmt"
"reflect"
"time" "time"
Assert "github.com/lbatuska/goutils/assert" Assert "github.com/lbatuska/goutils/assert"
@@ -114,111 +115,141 @@ func (opt *Optional[T]) OkOrElse(f func() error) Result[T] {
} }
func (opt *Optional[T]) Scan(src interface{}) error { func (opt *Optional[T]) Scan(src interface{}) error {
Assert.NotNil(opt)
opt.present = false
// DB had a null value // DB had a null value
if src == nil { if src == nil {
opt.present = false
return nil return nil
} }
// If T is a scanner // If T is a scanner
if scanner, ok := any(&opt.value).(sql.Scanner); ok { if scanner, ok := any(opt.value).(sql.Scanner); ok {
if err := scanner.Scan(src); err != nil { if err := scanner.Scan(src); err != nil {
opt.present = false
return err return err
} }
opt.present = true opt.present = true
return nil return nil
} }
// We implement parsing for some builtin types
if scanres := opt.scanBuiltin(src); scanres.IsSome() {
return scanres.Unwrap()
}
return fmt.Errorf("Unsupported type %T or differs from Optional[%T], and the type doesn't implement sql.Scanner!", src, opt.value)
}
func (opt *Optional[T]) scanBuiltin(src interface{}) Optional[error] {
opt.present = false opt.present = false
switch v := any(&opt.value).(type) { // First handle the special cases where we allow conversion between types
// This is usually just parsing []byte into type
if scanres := opt.scanTimeSpecial(src); scanres.IsSome() {
return scanres
}
if scanres := opt.scanStringSpecial(src); scanres.IsSome() {
return scanres
}
case *string: srcVal := reflect.ValueOf(src)
if str, ok := src.(string); ok { optType := reflect.TypeOf(opt.value)
*v = str
opt.present = true
return nil
}
if b, ok := src.([]byte); ok {
*v = string(b)
opt.present = true
return nil
}
case *int: optElemType := optType
if s, ok := src.(int); ok { if optElemType.Kind() == reflect.Pointer {
*v = s optElemType = optElemType.Elem()
opt.present = true }
return nil
}
case *int32:
if s, ok := src.(int32); ok {
*v = s
opt.present = true
return nil
}
case *int64:
if s, ok := src.(int64); ok {
*v = s
opt.present = true
return nil
}
case *bool: srcElemType := srcVal.Type()
switch s := src.(type) { if srcElemType.Kind() == reflect.Pointer {
case bool: srcElemType = srcElemType.Elem()
*v = s }
opt.present = true
return nil
// We could technically allow this, however we try to avoid implicit conversions to ensure type safety.
// case string:
// if s == "1" || s == "true" || s == "t" {
// *v = true
// opt.present = true
// return nil
// }
// if s == "0" || s == "false" || s == "f" {
// *v = false
// opt.present = true
// return nil
// }
}
case *float64: if srcElemType != optElemType {
switch s := src.(type) { return Some(fmt.Errorf("Optional[%T] (aka %T) differs from %T!", opt.value, opt.value, src))
case float64: }
*v = s
opt.present = true
return nil
case float32:
*v = float64(s)
opt.present = true
return nil
}
case *float32: if optType.Kind() == reflect.Pointer {
switch s := src.(type) { if srcVal.Kind() == reflect.Pointer {
case float32: opt.value = srcVal.Interface().(T)
*v = s } else {
opt.present = true newPtr := reflect.New(optElemType)
return nil newPtr.Elem().Set(srcVal)
opt.value = newPtr.Interface().(T)
} }
} else {
case *time.Time: if srcVal.Kind() == reflect.Ptr {
if t, ok := src.(time.Time); ok { opt.value = srcVal.Elem().Interface().(T)
*v = t } else {
opt.present = true opt.value = srcVal.Interface().(T)
return nil
}
if b, ok := src.([]byte); ok {
parsedTime, err := time.Parse(time.RFC3339, string(b))
if err == nil {
*v = parsedTime
opt.present = true
return nil
}
} }
} }
// We couldnt parse the value opt.present = true
opt.present = false return Some[error](nil)
return fmt.Errorf("unsupported type %T or differs from Optional[%T], and the type doesn't implement sql.Scanner", src, opt.value) }
func (opt *Optional[T]) scanStringSpecial(src interface{}) Optional[error] {
opt.present = false
switch v := any(opt.value).(type) {
case *string:
switch s := src.(type) {
case []byte:
*v = string(s)
goto ok
case *[]byte:
*v = string(*s)
goto ok
}
case string:
switch s := src.(type) {
case []byte:
reflect.ValueOf(&opt.value).Elem().Set(reflect.ValueOf(string(s)))
goto ok
case *[]byte:
reflect.ValueOf(&opt.value).Elem().Set(reflect.ValueOf(string(*s)))
goto ok
}
}
return None[error]()
ok:
opt.present = true
return Some[error](nil)
}
func (opt *Optional[T]) scanTimeSpecial(src interface{}) Optional[error] {
opt.present = false
switch v := any(opt.value).(type) {
case *time.Time:
switch t := src.(type) {
case []byte:
parsedTime, err := time.Parse(time.RFC3339, string(t))
if err == nil {
*v = parsedTime
goto ok
}
case *[]byte:
parsedTime, err := time.Parse(time.RFC3339, string(*t))
if err == nil {
*v = parsedTime
goto ok
}
}
case time.Time:
switch t := src.(type) {
case []byte:
parsedTime, err := time.Parse(time.RFC3339, string(t))
if err == nil {
reflect.ValueOf(&opt.value).Elem().Set(reflect.ValueOf(parsedTime))
goto ok
}
case *[]byte:
parsedTime, err := time.Parse(time.RFC3339, string(*t))
if err == nil {
reflect.ValueOf(&opt.value).Elem().Set(reflect.ValueOf(parsedTime))
goto ok
}
}
}
return None[error]()
ok:
opt.present = true
return Some[error](nil)
} }

View File

@@ -18,6 +18,20 @@ var (
nilOptional = (*Optional[int])(nil) nilOptional = (*Optional[int])(nil)
) )
func Test_optionalScan(t *testing.T) {
a := None[string]()
b := None[*string]()
c := None[string]()
d := None[int]()
a.Scan("a")
b.Scan("b")
c.Scan(Ptr("c"))
Testing.AssertEqual(t, "a", a.Unwrap())
Testing.AssertEqual(t, "b", *b.Unwrap())
Testing.AssertEqual(t, "c", c.Unwrap())
Testing.AssertPanic(t, func() { d.Unwrap() })
}
func Test_optionalSome(t *testing.T) { func Test_optionalSome(t *testing.T) {
Testing.AssertEqual(t, x, y) Testing.AssertEqual(t, x, y)
Testing.AssertEqual(t, v, u) Testing.AssertEqual(t, v, u)