diff --git a/cache.go b/cache.go index 81821ad..aeea675 100644 --- a/cache.go +++ b/cache.go @@ -83,13 +83,29 @@ package cache // 2 import ( + "encoding/gob" "fmt" + "io" + "os" "reflect" "runtime" "sync" "time" ) +type Item struct { + Object interface{} + Expiration *time.Time +} + +// Returns true if the item has expired. +func (i *Item) Expired() bool { + if i.Expiration == nil { + return false + } + return i.Expiration.Before(time.Now()) +} + type Cache struct { *cache // If this is confusing, see the comment at the bottom of the New() function @@ -102,16 +118,6 @@ type cache struct { janitor *janitor } -type Item struct { - Object interface{} - Expiration *time.Time -} - -type janitor struct { - Interval time.Duration - stop chan bool -} - // Adds an item to the cache, replacing any existing item. If the duration is 0, the // cache's default expiration time is used. If it is -1, the item never expires. func (c *cache) Set(k string, x interface{}, d time.Duration) { @@ -272,6 +278,63 @@ func (c *cache) DeleteExpired() { } } +// Writes the cache's items using Gob to an io.Writer +func (c *cache) Save(w io.Writer) error { + enc := gob.NewEncoder(w) + + defer func() { + if x := recover(); x != nil { + fmt.Printf(`The Gob library paniced while registering the cache's item types! +Information: %v + +The cache will not be saved. +Please report under what conditions this happened, and particularly what special type of objects +were stored in cache, at https://github.com/pmylund/go-cache/issues/new`, x) + } + }() + for _, v := range c.Items { + gob.Register(v.Object) + } + err := enc.Encode(&c.Items) + return err +} + +// Saves the cache's items to the given filename, creating the file if it +// doesn't exist, and overwriting it if it does +func (c *cache) SaveFile(fname string) error { + fp, err := os.Create(fname) + if err != nil { + return err + } + return c.Save(fp) +} + +// Adds gob-serialized cache items from an io.Reader, excluding any items that +// already exist in the current cache +func (c *cache) Load(r io.Reader) error { + dec := gob.NewDecoder(r) + items := map[string]*Item{} + err := dec.Decode(&items) + if err == nil { + for k, v := range items { + _, found := c.Items[k] + if !found { + c.Items[k] = v + } + } + } + return err +} + +// Loads and adds cache items from the given filename +func (c *cache) LoadFile(fname string) error { + fp, err := os.Open(fname) + if err != nil { + return err + } + return c.Load(fp) +} + // Deletes all items from the cache. func (c *cache) Flush() { c.mu.Lock() @@ -280,12 +343,9 @@ func (c *cache) Flush() { c.Items = map[string]*Item{} } -// Returns true if the item has expired. -func (i *Item) Expired() bool { - if i.Expiration == nil { - return false - } - return i.Expiration.Before(time.Now()) +type janitor struct { + Interval time.Duration + stop chan bool } func (j *janitor) Run(c *cache) { diff --git a/cache_test.go b/cache_test.go index 4e4a7f7..c11c8a6 100644 --- a/cache_test.go +++ b/cache_test.go @@ -1,6 +1,7 @@ package cache import ( + "bytes" "testing" "time" ) @@ -452,6 +453,104 @@ func TestDecrementUnderflowUint(t *testing.T) { } } +func TestCacheSerialization(t *testing.T) { + tc := New(0, 0) + testFillAndSerialize(t, tc) + + // Check if gob.Register behaves properly even after multiple gob.Register + // on c.Items (many of which will be the same type) + testFillAndSerialize(t, tc) +} + +func testFillAndSerialize(t *testing.T, tc *Cache) { + tc.Set("a", "a", 0) + tc.Set("b", "b", 0) + tc.Set("*struct", &TestStruct{Num: 1}, 0) + tc.Set("[]struct", []TestStruct{ + {Num: 2}, + {Num: 3}, + }, 0) + tc.Set("[]*struct", []*TestStruct{ + &TestStruct{Num: 4}, + &TestStruct{Num: 5}, + }, 0) + tc.Set("c", "c", 0) // ordering should be meaningless, but just in case + + fp := &bytes.Buffer{} + err := tc.Save(fp) + if err != nil { + t.Fatal("Couldn't save cache to fp:", err) + } + + oc := New(0, 0) + err = oc.Load(fp) + if err != nil { + t.Fatal("Couldn't load cache from fp:", err) + } + + a, found := oc.Get("a") + if !found { + t.Error("a was not found") + } + if a.(string) != "a" { + t.Error("a is not a") + } + + b, found := oc.Get("b") + if !found { + t.Error("b was not found") + } + if b.(string) != "b" { + t.Error("b is not b") + } + + c, found := oc.Get("c") + if !found { + t.Error("c was not found") + } + if c.(string) != "c" { + t.Error("c is not c") + } + + s1, found := oc.Get("*struct") + if !found { + t.Error("*struct was not found") + } + if s1.(*TestStruct).Num != 1 { + t.Error("*struct.Num is not 1") + } + + s2, found := oc.Get("[]struct") + if !found { + t.Error("[]struct was not found") + } + s2r := s2.([]TestStruct) + if len(s2r) != 2 { + t.Error("Length of s2r is not 2") + } + if s2r[0].Num != 2 { + t.Error("s2r[0].Num is not 2") + } + if s2r[1].Num != 3 { + t.Error("s2r[1].Num is not 3") + } + + s3, found := oc.get("[]*struct") + if !found { + t.Error("[]*struct was not found") + } + s3r := s3.([]*TestStruct) + if len(s3r) != 2 { + t.Error("Length of s3r is not 2") + } + if s3r[0].Num != 4 { + t.Error("s3r[0].Num is not 4") + } + if s3r[1].Num != 5 { + t.Error("s3r[1].Num is not 5") + } +} + func BenchmarkCacheGet(b *testing.B) { tc := New(0, 0) tc.Set("foo", "bar", 0)