diff --git a/.travis.yml b/.travis.yml index 788ee8e..92b448d 100644 --- a/.travis.yml +++ b/.travis.yml @@ -16,7 +16,8 @@ install: script: - make install-dev-deps - make check-style - - make test + - make test-coverage + - make codecov matrix: fast_finish: true diff --git a/bpe.go b/bpe.go new file mode 100644 index 0000000..aca81fd --- /dev/null +++ b/bpe.go @@ -0,0 +1,178 @@ +package bpe + +import ( + "encoding/binary" + "errors" + "io" + + "github.com/sirupsen/logrus" +) + +// TokenID is a numerical identitier of the subword token +type TokenID uint32 + +// EncodedToken is a sequence of subword tokens ids +type EncodedToken []TokenID + +type rule struct { + left TokenID + right TokenID + result TokenID +} + +type specialTokens struct { + unk int32 + pad int32 + bos int32 + eos int32 +} + +// Model is a Byte-Pair encoding model, which supports encoding and decoding text into sequences +// of most frequent subword tokens +type Model struct { + char2id map[rune]TokenID + id2char map[TokenID]rune + rules []rule + recipe map[TokenID]EncodedToken + revRecipe map[string]TokenID + specialTokens specialTokens +} + +func newModel(nRules int) *Model { + return &Model{ + make(map[rune]TokenID), + make(map[TokenID]rune), + make([]rule, nRules), + make(map[TokenID]EncodedToken), + make(map[string]TokenID), + specialTokens{-1, -1, -1, -1}, + } +} + +// DecodeToken converts the sequence of chars' ids into the string - +// sequence of the corresponding chars +func DecodeToken(token EncodedToken, id2char map[TokenID]rune) (string, error) { + word := "" + for _, id := range token { + if char, ok := id2char[id]; ok { + word = word + string(char) + } else { + logrus.Errorf("Decode failure: %d token id has no corresponding char", id) + return "", errors.New("key not found in id2char") + } + } + return word, nil +} + +func (s specialTokens) toBinary() []byte { + bytesArray := make([]byte, 16) + binary.BigEndian.PutUint32(bytesArray, uint32(s.unk)) + binary.BigEndian.PutUint32(bytesArray[4:], uint32(s.pad)) + binary.BigEndian.PutUint32(bytesArray[8:], uint32(s.bos)) + binary.BigEndian.PutUint32(bytesArray[12:], uint32(s.eos)) + return bytesArray +} + +func binaryToSpecialTokens(bytesArray []byte) (specialTokens, error) { + var s specialTokens + if len(bytesArray) < 16 { + logrus.Error("Bytes array length is too small") + return s, errors.New("bytes array is too small") + } + s.unk = int32(binary.BigEndian.Uint32(bytesArray)) + s.pad = int32(binary.BigEndian.Uint32(bytesArray[4:])) + s.bos = int32(binary.BigEndian.Uint32(bytesArray[8:])) + s.eos = int32(binary.BigEndian.Uint32(bytesArray[12:])) + return s, nil +} + +func (r rule) toBinary() []byte { + bytesArray := make([]byte, 12) + binary.BigEndian.PutUint32(bytesArray, uint32(r.left)) + binary.BigEndian.PutUint32(bytesArray[4:], uint32(r.right)) + binary.BigEndian.PutUint32(bytesArray[8:], uint32(r.result)) + return bytesArray +} + +func binaryToRule(bytesArray []byte) (rule, error) { + var r rule + if len(bytesArray) < 12 { + logrus.Error("Bytes array length is too small") + return r, errors.New("bytes array is too small") + } + r.left = TokenID(binary.BigEndian.Uint32(bytesArray)) + r.right = TokenID(binary.BigEndian.Uint32(bytesArray[4:])) + r.result = TokenID(binary.BigEndian.Uint32(bytesArray[8:])) + return r, nil +} + +// ReadModel loads the BPE model from the binary dump +func ReadModel(reader io.Reader) (*Model, error) { + buf := make([]byte, 4) + var nChars, nRules int + if _, err := io.ReadFull(reader, buf); err != nil { + logrus.Error("Broken input: ", err) + return &Model{}, err + } + nChars = int(binary.BigEndian.Uint32(buf)) + if _, err := io.ReadFull(reader, buf); err != nil { + logrus.Error("Broken input: ", err) + return &Model{}, err + } + nRules = int(binary.BigEndian.Uint32(buf)) + + model := newModel(nRules) + for i := 0; i < nChars; i++ { + var char rune + var charID TokenID + if _, err := io.ReadFull(reader, buf); err != nil { + logrus.Error("Broken input: ", err) + return &Model{}, err + } + char = rune(binary.BigEndian.Uint32(buf)) + if _, err := io.ReadFull(reader, buf); err != nil { + logrus.Error("Broken input: ", err) + return &Model{}, err + } + charID = TokenID(binary.BigEndian.Uint32(buf)) + model.char2id[char] = charID + model.id2char[charID] = char + model.recipe[charID] = EncodedToken{charID} + model.revRecipe[string(char)] = charID + } + ruleBuf := make([]byte, 12) + for i := 0; i < nRules; i++ { + if _, err := io.ReadFull(reader, ruleBuf); err != nil { + logrus.Error("Broken input: ", err) + return &Model{}, err + } + rule, err := binaryToRule(ruleBuf) + if err != nil { + return model, err + } + model.rules[i] = rule + if _, ok := model.recipe[rule.left]; !ok { + logrus.Errorf("%d: token id not described before", rule.left) + return model, errors.New("key not found in id2char") + } + if _, ok := model.recipe[rule.right]; !ok { + logrus.Errorf("%d: token id not described before", rule.right) + return model, errors.New("key not found in id2char") + } + model.recipe[rule.result] = append(model.recipe[rule.left], model.recipe[rule.right]...) + resultString, err := DecodeToken(model.recipe[rule.result], model.id2char) + if err != nil { + logrus.Error("Unexpected token id inside the rules: ", err) + return model, err + } + model.revRecipe[resultString] = rule.result + } + specialTokensBuf := make([]byte, 16) + if _, err := io.ReadFull(reader, specialTokensBuf); err != nil { + logrus.Error("Broken input: ", err) + return &Model{}, err + } + specials, err := binaryToSpecialTokens(specialTokensBuf) + model.specialTokens = specials + return model, err +} diff --git a/bpe_test.go b/bpe_test.go new file mode 100644 index 0000000..e69398b --- /dev/null +++ b/bpe_test.go @@ -0,0 +1,131 @@ +package bpe + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNewModel(t *testing.T) { + model := newModel(10) + require.Equal(t, 10, len(model.rules)) +} + +func TestDecodeToken(t *testing.T) { + id2char := map[TokenID]rune{1: []rune("a")[0], 2: []rune("b")[0], 3: []rune("c")[0]} + word, err := DecodeToken(EncodedToken{1, 2, 1, 3, 3}, id2char) + require.NoError(t, err) + require.Equal(t, "abacc", word) +} + +func TestSpecialTokensToBinary(t *testing.T) { + specials := specialTokens{1, 259, 2*256*256 + 37*256 + 2, -256 * 256 * 256 * 127} + bytesArray := []byte{0, 0, 0, 1, 0, 0, 1, 3, 0, 2, 37, 2, 129, 0, 0, 0} + require.Equal(t, bytesArray, specials.toBinary()) +} + +func TestBinaryToSpecialTokens(t *testing.T) { + bytesArray := []byte{0, 0, 0, 1, 0, 0, 1, 3, 0, 2, 37, 2, 129, 0, 0, 0} + expected := specialTokens{1, 259, 2*256*256 + 37*256 + 2, -256 * 256 * 256 * 127} + specials, err := binaryToSpecialTokens(bytesArray) + require.NoError(t, err) + require.Equal(t, expected, specials) + bytesArray = []byte{0, 0, 0, 1, 0, 0, 1, 3, 0, 2, 37, 2, 129, 0, 0} + specials, err = binaryToSpecialTokens(bytesArray) + require.Error(t, err) + bytesArray = []byte{} + specials, err = binaryToSpecialTokens(bytesArray) + require.Error(t, err) +} + +func TestRuleToBinary(t *testing.T) { + rule := rule{1, 2, 257} + bytesArray := []byte{0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 1, 1} + require.Equal(t, bytesArray, rule.toBinary()) +} + +func TestBinaryToRule(t *testing.T) { + expected := rule{1, 2, 257} + bytesArray := []byte{0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 1, 1} + rule, err := binaryToRule(bytesArray) + require.NoError(t, err) + require.Equal(t, expected, rule) + bytesArray = []byte{0, 0, 0, 0, 0, 0, 2, 0, 0, 1, 1} + rule, err = binaryToRule(bytesArray) + require.Error(t, err) + bytesArray = []byte{} + rule, err = binaryToRule(bytesArray) + require.Error(t, err) +} + +func TestReadModel(t *testing.T) { + reader := bytes.NewReader([]byte{0, 0, 0, 5, 0, 0, 0, 4, + 0, 0, 0, 99, 0, 0, 0, 6, + 0, 0, 0, 98, 0, 0, 0, 7, + 0, 0, 0, 95, 0, 0, 0, 4, + 0, 0, 0, 100, 0, 0, 0, 5, + 0, 0, 0, 97, 0, 0, 0, 8, + 0, 0, 0, 4, 0, 0, 0, 8, 0, 0, 0, 9, + 0, 0, 0, 4, 0, 0, 0, 6, 0, 0, 0, 10, + 0, 0, 0, 4, 0, 0, 0, 5, 0, 0, 0, 11, + 0, 0, 0, 4, 0, 0, 0, 7, 0, 0, 0, 12, + 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 3}) + expected := Model{ + map[rune]TokenID{97: 8, 98: 7, 99: 6, 100: 5, 95: 4}, + map[TokenID]rune{4: 95, 5: 100, 6: 99, 7: 98, 8: 97}, + []rule{{4, 8, 9}, {4, 6, 10}, {4, 5, 11}, {4, 7, 12}}, + map[TokenID]EncodedToken{4: {4}, 5: {5}, 6: {6}, 7: {7}, 8: {8}, 9: {4, 8}, 10: {4, 6}, 11: {4, 5}, 12: {4, 7}}, + map[string]TokenID{"a": 8, "b": 7, "c": 6, "d": 5, "_": 4, + "_a": 9, "_b": 12, "_c": 10, "_d": 11}, + specialTokens{1, 0, 2, 3}, + } + model, err := ReadModel(reader) + require.NoError(t, err) + require.Equal(t, expected, *model) + + reader = bytes.NewReader([]byte{0, 0, 0, 5, 0, 0, 0, 4, + 0, 0, 0, 99, 0, 0, 0, 6, + 0, 0, 0, 98, 0, 0, 0, 7, + 0, 0, 0, 95, 0, 0, 0, 4, + 0, 0, 0, 100, 0, 0, 0, 5, + 0, 0, 0, 97, 0, 0, 0, 8, + 0, 0, 0, 4, 0, 0, 0, 8, 0, 0, 0, 9, + 0, 0, 0, 4, 0, 0, 0, 6, 0, 0, 0, 10, + 0, 0, 0, 4, 0, 0, 0, 5, 0, 0, 0, 11, + 0, 0, 0, 4, 0, 0, 0, 7, 0, 0, 0, 12, + 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 3, + 0, 0, 0, 4, 0, 0, 0, 5, 0, 0, 0, 11, + 0, 0, 0, 4, 0, 0, 0, 7, 0, 0, 0, 12}) + model, err = ReadModel(reader) + require.NoError(t, err) + require.Equal(t, expected, *model) + + reader = bytes.NewReader([]byte{0, 0, 0, 5, 0, 0, 0, 4, + 0, 0, 0, 99, 0, 0, 0, 6, + 0, 0, 0, 98, 0, 0, 0, 7, + 0, 0, 0, 95, 0, 0, 0, 4, + 0, 0, 0, 100, 0, 0, 0, 5, + 0, 0, 0, 97, 0, 0, 0, 8, + 0, 0, 0, 4, 0, 0, 0, 8, 0, 0, 0, 9, + 0, 0, 0, 4, 0, 0, 0, 6, 0, 0, 0, 10, + 0, 0, 0, 4, 0, 0, 0, 5, 0, 0, 0, 11, + 0, 0, 0, 4, 0, 0, 0, 7, 0, 0, 0, 12, + 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0}) + model, err = ReadModel(reader) + require.Error(t, err) + + reader = bytes.NewReader([]byte{0, 0, 0, 5, 0, 0, 0, 4, + 0, 0, 0, 99, 0, 0, 0, 6, + 0, 0, 0, 98, 0, 0, 0, 7, + 0, 0, 0, 95, 0, 0, 0, 4, + 0, 0, 0, 100, 0, 0, 0, 5, + 0, 0, 0, 97, 0, 0, 0, 8, + 0, 0, 0, 4, 0, 0, 0, 20, 0, 0, 0, 9, + 0, 0, 0, 4, 0, 0, 0, 6, 0, 0, 0, 10, + 0, 0, 0, 4, 0, 0, 0, 5, 0, 0, 0, 11, + 0, 0, 0, 4, 0, 0, 0, 7, 0, 0, 0, 12, + 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 3}) + model, err = ReadModel(reader) + require.Error(t, err) +} diff --git a/go.mod b/go.mod index 7d8e300..510e312 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,8 @@ module github.com/src-d/go-YouTokenToMe go 1.12 + +require ( + github.com/sirupsen/logrus v1.4.2 + github.com/stretchr/testify v1.4.0 +) diff --git a/go.sum b/go.sum deleted file mode 100644 index e69de29..0000000 diff --git a/main.go b/main.go deleted file mode 100644 index 953cd3d..0000000 --- a/main.go +++ /dev/null @@ -1,7 +0,0 @@ -package main - -import "fmt" - -func main() { - fmt.Printf("Package for applying BPE") -}