Skip to content

Commit

Permalink
Expand tests
Browse files Browse the repository at this point in the history
Signed-off-by: Irina Khismatullina <[email protected]>
  • Loading branch information
irinakhismatullina committed Oct 22, 2019
1 parent 9bbc64d commit 39b93ca
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 47 deletions.
87 changes: 53 additions & 34 deletions bpe.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package bpe

import (
"bufio"
"encoding/binary"
"errors"
"io"

"github.com/sirupsen/logrus"
Expand Down Expand Up @@ -57,60 +57,68 @@ func DecodeToken(token EncodedToken, id2char map[TokenID]rune) (string, error) {
if char, ok := id2char[id]; ok {
word = word + string(char)
} else {
logrus.Fatalf("%d key not found in id2char", id)
logrus.Errorf("Decode failure: %d token id has no corresponding char", id)
return "", errors.New("key not found in id2char")
}
}
return word, nil
}

func specialTokensToBin(specials specialTokens) []byte {
func (s specialTokens) toBinary() []byte {
bytesArray := make([]byte, 16)
binary.BigEndian.PutUint32(bytesArray, uint32(specials.unk))
binary.BigEndian.PutUint32(bytesArray[4:], uint32(specials.pad))
binary.BigEndian.PutUint32(bytesArray[8:], uint32(specials.bos))
binary.BigEndian.PutUint32(bytesArray[12:], uint32(specials.eos))
binary.BigEndian.PutUint32(bytesArray, uint32(s.unk))
binary.BigEndian.PutUint32(bytesArray[4:], uint32(s.pad))
binary.BigEndian.PutUint32(bytesArray[8:], uint32(s.bos))
binary.BigEndian.PutUint32(bytesArray[12:], uint32(s.eos))
return bytesArray
}

func binToSpecialTokens(bytesArray []byte) specialTokens {
func binaryToSpecialTokens(bytesArray []byte) (specialTokens, error) {
var s specialTokens
if len(bytesArray) < 16 {
logrus.Error("Bytes array length is too small")
return s, errors.New("bytes array is too small")
}
s.unk = int32(binary.BigEndian.Uint32(bytesArray))
s.pad = int32(binary.BigEndian.Uint32(bytesArray[4:]))
s.bos = int32(binary.BigEndian.Uint32(bytesArray[8:]))
s.eos = int32(binary.BigEndian.Uint32(bytesArray[12:]))
return s
return s, nil
}

func ruleToBin(rule rule) []byte {
func (r rule) toBinary() []byte {
bytesArray := make([]byte, 12)
binary.BigEndian.PutUint32(bytesArray, uint32(rule.left))
binary.BigEndian.PutUint32(bytesArray[4:], uint32(rule.right))
binary.BigEndian.PutUint32(bytesArray[8:], uint32(rule.result))
binary.BigEndian.PutUint32(bytesArray, uint32(r.left))
binary.BigEndian.PutUint32(bytesArray[4:], uint32(r.right))
binary.BigEndian.PutUint32(bytesArray[8:], uint32(r.result))
return bytesArray
}

func binToRule(bytesArray []byte) rule {
func binaryToRule(bytesArray []byte) (rule, error) {
var r rule
if len(bytesArray) < 12 {
logrus.Error("Bytes array length is too small")
return r, errors.New("bytes array is too small")
}
r.left = TokenID(binary.BigEndian.Uint32(bytesArray))
r.right = TokenID(binary.BigEndian.Uint32(bytesArray[4:]))
r.result = TokenID(binary.BigEndian.Uint32(bytesArray[8:]))
return r
return r, nil
}

// ReadModelFromBinary loads the BPE model from the binary dump
func ReadModelFromBinary(reader io.Reader) (*Model, error) {
bytesReader := bufio.NewReader(reader)
// ReadModel loads the BPE model from the binary dump
func ReadModel(reader io.Reader) (*Model, error) {
buf := make([]byte, 4)
var nChars, nRules int
_, err := bytesReader.Read(buf)
_, err := io.ReadFull(reader, buf)
if err != nil {
logrus.Fatal("Broken input: ", err)
logrus.Error("Broken input: ", err)
return &Model{}, err
}
nChars = int(binary.BigEndian.Uint32(buf))
_, err = bytesReader.Read(buf)
_, err = io.ReadFull(reader, buf)
if err != nil {
logrus.Fatal("Broken input: ", err)
logrus.Error("Broken input: ", err)
return &Model{}, err
}
nRules = int(binary.BigEndian.Uint32(buf))
Expand All @@ -119,15 +127,15 @@ func ReadModelFromBinary(reader io.Reader) (*Model, error) {
for i := 0; i < nChars; i++ {
var char rune
var charID TokenID
_, err = bytesReader.Read(buf)
_, err = io.ReadFull(reader, buf)
if err != nil {
logrus.Fatal("Broken input: ", err)
logrus.Error("Broken input: ", err)
return &Model{}, err
}
char = rune(binary.BigEndian.Uint32(buf))
_, err = bytesReader.Read(buf)
_, err = io.ReadFull(reader, buf)
if err != nil {
logrus.Fatal("Broken input: ", err)
logrus.Error("Broken input: ", err)
return &Model{}, err
}
charID = TokenID(binary.BigEndian.Uint32(buf))
Expand All @@ -138,27 +146,38 @@ func ReadModelFromBinary(reader io.Reader) (*Model, error) {
}
ruleBuf := make([]byte, 12)
for i := 0; i < nRules; i++ {
_, err = bytesReader.Read(ruleBuf)
_, err = io.ReadFull(reader, ruleBuf)
if err != nil {
logrus.Fatal("Broken input: ", err)
logrus.Error("Broken input: ", err)
return &Model{}, err
}
rule := binToRule(ruleBuf)
rule, err := binaryToRule(ruleBuf)
if err != nil {
return model, err
}
model.rules[i] = rule
if _, ok := model.recipe[rule.left]; !ok {
logrus.Errorf("%d: token id not described before", rule.left)
return model, errors.New("key not found in id2char")
}
if _, ok := model.recipe[rule.right]; !ok {
logrus.Errorf("%d: token id not described before", rule.right)
return model, errors.New("key not found in id2char")
}
model.recipe[rule.result] = append(model.recipe[rule.left], model.recipe[rule.right]...)
resultString, err := DecodeToken(model.recipe[rule.result], model.id2char)
if err != nil {
logrus.Fatal("Unexpected token id inside the rules: ", err)
logrus.Error("Unexpected token id inside the rules: ", err)
return model, err
}
model.revRecipe[resultString] = rule.result
}
specialTokensBuf := make([]byte, 16)
_, err = bytesReader.Read(specialTokensBuf)
_, err = io.ReadFull(reader, specialTokensBuf)
if err != nil {
logrus.Fatal("Broken input: ", err)
logrus.Error("Broken input: ", err)
return &Model{}, err
}
model.specialTokens = binToSpecialTokens(specialTokensBuf)
return model, nil
model.specialTokens, err = binaryToSpecialTokens(specialTokensBuf)
return model, err
}
87 changes: 74 additions & 13 deletions bpe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,38 +12,54 @@ func TestNewModel(t *testing.T) {
require.Equal(t, 10, len(model.rules))
}

func TestDecodedTokenToString(t *testing.T) {
func TestDecodeToken(t *testing.T) {
id2char := map[TokenID]rune{1: []rune("a")[0], 2: []rune("b")[0], 3: []rune("c")[0]}
word, err := DecodeToken(EncodedToken{1, 2, 1, 3, 3}, id2char)
require.NoError(t, err)
require.Equal(t, "abacc", word)
}

func TestSpecialTokensToBin(t *testing.T) {
func TestSpecialTokensToBinary(t *testing.T) {
specials := specialTokens{1, 259, 2*256*256 + 37*256 + 2, -256 * 256 * 256 * 127}
bytesArray := []byte{0, 0, 0, 1, 0, 0, 1, 3, 0, 2, 37, 2, 129, 0, 0, 0}
require.Equal(t, bytesArray, specialTokensToBin(specials))
require.Equal(t, bytesArray, specials.toBinary())
}

func TestBinToSpecialTokens(t *testing.T) {
func TestBinaryToSpecialTokens(t *testing.T) {
bytesArray := []byte{0, 0, 0, 1, 0, 0, 1, 3, 0, 2, 37, 2, 129, 0, 0, 0}
specials := specialTokens{1, 259, 2*256*256 + 37*256 + 2, -256 * 256 * 256 * 127}
require.Equal(t, specials, binToSpecialTokens(bytesArray))
expected := specialTokens{1, 259, 2*256*256 + 37*256 + 2, -256 * 256 * 256 * 127}
specials, err := binaryToSpecialTokens(bytesArray)
require.NoError(t, err)
require.Equal(t, expected, specials)
bytesArray = []byte{0, 0, 0, 1, 0, 0, 1, 3, 0, 2, 37, 2, 129, 0, 0}
specials, err = binaryToSpecialTokens(bytesArray)
require.Error(t, err)
bytesArray = []byte{}
specials, err = binaryToSpecialTokens(bytesArray)
require.Error(t, err)
}

func TestRuleToBin(t *testing.T) {
func TestRuleToBinary(t *testing.T) {
rule := rule{1, 2, 257}
bytesArray := []byte{0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 1, 1}
require.Equal(t, bytesArray, ruleToBin(rule))
require.Equal(t, bytesArray, rule.toBinary())
}

func TestBinToRule(t *testing.T) {
rule := rule{1, 2, 257}
func TestBinaryToRule(t *testing.T) {
expected := rule{1, 2, 257}
bytesArray := []byte{0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 1, 1}
require.Equal(t, rule, binToRule(bytesArray))
rule, err := binaryToRule(bytesArray)
require.NoError(t, err)
require.Equal(t, expected, rule)
bytesArray = []byte{0, 0, 0, 0, 0, 0, 2, 0, 0, 1, 1}
rule, err = binaryToRule(bytesArray)
require.Error(t, err)
bytesArray = []byte{}
rule, err = binaryToRule(bytesArray)
require.Error(t, err)
}

func TestReadModelFromBinary(t *testing.T) {
func TestReadModel(t *testing.T) {
reader := bytes.NewReader([]byte{0, 0, 0, 5, 0, 0, 0, 4,
0, 0, 0, 99, 0, 0, 0, 6,
0, 0, 0, 98, 0, 0, 0, 7,
Expand All @@ -64,7 +80,52 @@ func TestReadModelFromBinary(t *testing.T) {
"_a": 9, "_b": 12, "_c": 10, "_d": 11},
specialTokens{1, 0, 2, 3},
}
model, err := ReadModelFromBinary(reader)
model, err := ReadModel(reader)
require.NoError(t, err)
require.Equal(t, expected, *model)

reader = bytes.NewReader([]byte{0, 0, 0, 5, 0, 0, 0, 4,
0, 0, 0, 99, 0, 0, 0, 6,
0, 0, 0, 98, 0, 0, 0, 7,
0, 0, 0, 95, 0, 0, 0, 4,
0, 0, 0, 100, 0, 0, 0, 5,
0, 0, 0, 97, 0, 0, 0, 8,
0, 0, 0, 4, 0, 0, 0, 8, 0, 0, 0, 9,
0, 0, 0, 4, 0, 0, 0, 6, 0, 0, 0, 10,
0, 0, 0, 4, 0, 0, 0, 5, 0, 0, 0, 11,
0, 0, 0, 4, 0, 0, 0, 7, 0, 0, 0, 12,
0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 3,
0, 0, 0, 4, 0, 0, 0, 5, 0, 0, 0, 11,
0, 0, 0, 4, 0, 0, 0, 7, 0, 0, 0, 12})
model, err = ReadModel(reader)
require.NoError(t, err)
require.Equal(t, expected, *model)

reader = bytes.NewReader([]byte{0, 0, 0, 5, 0, 0, 0, 4,
0, 0, 0, 99, 0, 0, 0, 6,
0, 0, 0, 98, 0, 0, 0, 7,
0, 0, 0, 95, 0, 0, 0, 4,
0, 0, 0, 100, 0, 0, 0, 5,
0, 0, 0, 97, 0, 0, 0, 8,
0, 0, 0, 4, 0, 0, 0, 8, 0, 0, 0, 9,
0, 0, 0, 4, 0, 0, 0, 6, 0, 0, 0, 10,
0, 0, 0, 4, 0, 0, 0, 5, 0, 0, 0, 11,
0, 0, 0, 4, 0, 0, 0, 7, 0, 0, 0, 12,
0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0})
model, err = ReadModel(reader)
require.Error(t, err)

reader = bytes.NewReader([]byte{0, 0, 0, 5, 0, 0, 0, 4,
0, 0, 0, 99, 0, 0, 0, 6,
0, 0, 0, 98, 0, 0, 0, 7,
0, 0, 0, 95, 0, 0, 0, 4,
0, 0, 0, 100, 0, 0, 0, 5,
0, 0, 0, 97, 0, 0, 0, 8,
0, 0, 0, 4, 0, 0, 0, 20, 0, 0, 0, 9,
0, 0, 0, 4, 0, 0, 0, 6, 0, 0, 0, 10,
0, 0, 0, 4, 0, 0, 0, 5, 0, 0, 0, 11,
0, 0, 0, 4, 0, 0, 0, 7, 0, 0, 0, 12,
0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 3})
model, err = ReadModel(reader)
require.Error(t, err)
}
18 changes: 18 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/sirupsen/logrus v1.4.2 h1:SPIRibHv4MatM3XXNO2BJeFLZwZ2LvZgfQ5+UNI2im4=
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
golang.org/x/sys v0.0.0-20190422165155-953cdadca894 h1:Cz4ceDQGXuKRnVBDTS23GTn/pU5OE2C0WrNTOYK1Uuc=
golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=

0 comments on commit 39b93ca

Please sign in to comment.