Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add custom enum deserializer to improve error messaging, improve byte count error mesages #5076

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,13 @@ public static ByteCount parse(final String string) {
final String unitString = matcher.group("unit");

if(unitString == null) {
throw new ByteCountInvalidInputException("Byte counts must have a unit.");
throw new ByteCountInvalidInputException("Byte counts must have a unit. Valid byte units include: " +
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think, it is better to enhance the corresponding test case messages to assert for this updated message. Currently, they are only looking for not null.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll update the test case to assert for this message

Arrays.stream(Unit.values()).map(unitValue -> unitValue.unitString).collect(Collectors.toList()));
}

final Unit unit = Unit.fromString(unitString)
.orElseThrow(() -> new ByteCountInvalidInputException("Invalid byte unit: '" + unitString + "'"));
.orElseThrow(() -> new ByteCountInvalidInputException("Invalid byte unit: '" + unitString + "'. Valid byte units include: "
+ Arrays.stream(Unit.values()).map(unitValue -> unitValue.unitString).collect(Collectors.toList())));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably a nit. This Arrays.stream statement repeated twice. It is also there in line number 83. I think you can add a method in Unit enum class to return this array of code values to keep the logic at one place.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Line 51 already does something similar, though it creates a map which will be out of order. If you make it a linked map you could use the keys() from it.

private static final Map<String, Unit> UNIT_MAP = Arrays.stream(Unit.values())
.collect(Collectors.toMap(unit -> unit.unitString, Function.identity()));


final BigDecimal valueBigDecimal = new BigDecimal(valueString);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public ByteCount deserialize(final JsonParser parser, final DeserializationConte
try {
return ByteCount.parse(byteString);
} catch (final Exception ex) {
throw new IllegalArgumentException(ex);
throw new IllegalArgumentException(ex.getMessage());
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
package org.opensearch.dataprepper.pipeline.parser;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonValue;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.databind.BeanProperty;
import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.JavaType;
import com.fasterxml.jackson.databind.JsonDeserializer;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.deser.ContextualDeserializer;

import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;


/**
* This deserializer is used for any Enum classes when converting the pipeline configuration file into the plugin model classes
* @since 2.11
*/
public class EnumDeserializer extends JsonDeserializer<Enum<?>> implements ContextualDeserializer {

static final String INVALID_ENUM_VALUE_ERROR_FORMAT = "Invalid value \"%s\". Valid options include %s.";

private Class<?> enumClass;

public EnumDeserializer() {}

public EnumDeserializer(final Class<?> enumClass) {
if (!enumClass.isEnum()) {
throw new IllegalArgumentException("The provided class is not an enum: " + enumClass.getName());
}

this.enumClass = enumClass;
}
@Override
public Enum<?> deserialize(final JsonParser p, final DeserializationContext ctxt) throws IOException {
final JsonNode node = p.getCodec().readTree(p);
final String enumValue = node.asText();

final Optional<Method> jsonCreator = findJsonCreatorMethod();

try {
jsonCreator.ifPresent(method -> method.setAccessible(true));

for (Object enumConstant : enumClass.getEnumConstants()) {
try {
if (jsonCreator.isPresent() && enumConstant.equals(jsonCreator.get().invoke(null, enumValue))) {
return (Enum<?>) enumConstant;
} else if (jsonCreator.isEmpty() && enumConstant.toString().toLowerCase().equals(enumValue)) {
return (Enum<?>) enumConstant;
}
} catch (IllegalAccessException | InvocationTargetException e) {
throw new RuntimeException(e);
}
}
} finally {
jsonCreator.ifPresent(method -> method.setAccessible(false));
}



final Optional<Method> jsonValueMethod = findJsonValueMethodForClass();
final List<Object> listOfEnums = jsonValueMethod.map(method -> Arrays.stream(enumClass.getEnumConstants())
.map(valueEnum -> {
try {
return method.invoke(valueEnum);
} catch (IllegalAccessException | InvocationTargetException e) {
throw new RuntimeException(e);
}
})
.collect(Collectors.toList())).orElseGet(() -> Arrays.stream(enumClass.getEnumConstants())
.map(valueEnum -> valueEnum.toString().toLowerCase())
.collect(Collectors.toList()));

throw new IllegalArgumentException(String.format(INVALID_ENUM_VALUE_ERROR_FORMAT, enumValue, listOfEnums));
}

@Override
public JsonDeserializer<?> createContextual(final DeserializationContext ctxt, final BeanProperty property) {
final JavaType javaType = property.getType();
final Class<?> rawClass = javaType.getRawClass();

return new EnumDeserializer(rawClass);
}

private Optional<Method> findJsonValueMethodForClass() {
for (final Method method : enumClass.getDeclaredMethods()) {
if (method.isAnnotationPresent(JsonValue.class)) {
return Optional.of(method);
}
}

return Optional.empty();
}

private Optional<Method> findJsonCreatorMethod() {
for (final Method method : enumClass.getDeclaredMethods()) {
if (method.isAnnotationPresent(JsonCreator.class)) {
return Optional.of(method);
}
}

return Optional.empty();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.junit.jupiter.params.provider.ValueSource;
import org.opensearch.dataprepper.model.types.ByteCount;

import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.notNullValue;
import static org.hamcrest.MatcherAssert.assertThat;
Expand All @@ -31,9 +32,28 @@ void setUp() {
}

@ParameterizedTest
@ValueSource(strings = {"1", "1b 2b", "1vb", "bad"})
void convert_with_invalid_values_throws(final String invalidByteString) {
assertThrows(IllegalArgumentException.class, () -> objectMapper.convertValue(invalidByteString, ByteCount.class));
@ValueSource(strings = {"1", "10"})
void convert_with_no_byte_unit_throws_expected_exception(final String invalidByteString) {
final IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> objectMapper.convertValue(invalidByteString, ByteCount.class));
assertThat(exception.getMessage(), containsString("Byte counts must have a unit. Valid byte units include: [b, kb, mb, gb]"));
}

@ParameterizedTest
@ValueSource(strings = {"10 2b", "bad"})
void convert_with_non_parseable_values_throws(final String invalidByteString) {
final IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> objectMapper.convertValue(invalidByteString, ByteCount.class));
assertThat(exception.getMessage(), containsString("Unable to parse bytes"));
}

@ParameterizedTest
@CsvSource({
"10f, f",
"1vb, vb",
"3g, g"
})
void convert_with_invalid_byte_units_throws(final String invalidByteString, final String invalidUnit) {
final IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> objectMapper.convertValue(invalidByteString, ByteCount.class));
assertThat(exception.getMessage(), containsString("Invalid byte unit: '" + invalidUnit + "'. Valid byte units include: [b, kb, mb, gb]"));
}

@ParameterizedTest
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
package org.opensearch.dataprepper.pipeline.parser;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonValue;
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.databind.BeanProperty;
import com.fasterxml.jackson.databind.DeserializationContext;
import com.fasterxml.jackson.databind.JavaType;
import com.fasterxml.jackson.databind.JsonDeserializer;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.TextNode;
import org.hamcrest.Matchers;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource;
import org.opensearch.dataprepper.model.event.HandleFailedEventsOption;

import java.io.IOException;
import java.time.Duration;
import java.util.Arrays;
import java.util.Map;
import java.util.UUID;
import java.util.function.Function;
import java.util.stream.Collectors;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.notNullValue;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

public class EnumDeserializerTest {

private ObjectMapper objectMapper;

@BeforeEach
void setup() {
objectMapper = mock(ObjectMapper.class);
}

private EnumDeserializer createObjectUnderTest(final Class<?> enumClass) {
return new EnumDeserializer(enumClass);
}

@Test
void non_enum_class_throws_IllegalArgumentException() {
assertThrows(IllegalArgumentException.class, () -> new EnumDeserializer(Duration.class));
}

@ParameterizedTest
@EnumSource(TestEnum.class)
void enum_class_with_json_creator_annotation_returns_expected_enum_constant(final TestEnum testEnumOption) throws IOException {
final EnumDeserializer objectUnderTest = createObjectUnderTest(TestEnum.class);
final JsonParser jsonParser = mock(JsonParser.class);
final DeserializationContext deserializationContext = mock(DeserializationContext.class);
when(jsonParser.getCodec()).thenReturn(objectMapper);

when(objectMapper.readTree(jsonParser)).thenReturn(new TextNode(testEnumOption.toString()));

Enum<?> result = objectUnderTest.deserialize(jsonParser, deserializationContext);

assertThat(result, equalTo(testEnumOption));
}

@ParameterizedTest
@EnumSource(TestEnumWithoutJsonCreator.class)
void enum_class_without_json_creator_annotation_returns_expected_enum_constant(final TestEnumWithoutJsonCreator enumWithoutJsonCreator) throws IOException {
final EnumDeserializer objectUnderTest = createObjectUnderTest(TestEnumWithoutJsonCreator.class);
final JsonParser jsonParser = mock(JsonParser.class);
final DeserializationContext deserializationContext = mock(DeserializationContext.class);
when(jsonParser.getCodec()).thenReturn(objectMapper);

when(objectMapper.readTree(jsonParser)).thenReturn(new TextNode(enumWithoutJsonCreator.toString()));

Enum<?> result = objectUnderTest.deserialize(jsonParser, deserializationContext);

assertThat(result, equalTo(enumWithoutJsonCreator));
}

@Test
void enum_class_with_invalid_value_and_jsonValue_annotation_throws_IllegalArgumentException() throws IOException {
final EnumDeserializer objectUnderTest = createObjectUnderTest(TestEnum.class);
final JsonParser jsonParser = mock(JsonParser.class);
final DeserializationContext deserializationContext = mock(DeserializationContext.class);
when(jsonParser.getCodec()).thenReturn(objectMapper);

final String invalidValue = UUID.randomUUID().toString();
when(objectMapper.readTree(jsonParser)).thenReturn(new TextNode(invalidValue));

final IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () ->
objectUnderTest.deserialize(jsonParser, deserializationContext));

assertThat(exception, notNullValue());
final String expectedErrorMessage = "Invalid value \"" + invalidValue + "\". Valid options include";
assertThat(exception.getMessage(), Matchers.startsWith(expectedErrorMessage));
assertThat(exception.getMessage(), containsString("[test_display_one, test_display_two, test_display_three]"));
}

@Test
void enum_class_with_invalid_value_and_no_jsonValue_annotation_throws_IllegalArgumentException() throws IOException {
final EnumDeserializer objectUnderTest = createObjectUnderTest(TestEnumWithoutJsonCreator.class);
final JsonParser jsonParser = mock(JsonParser.class);
final DeserializationContext deserializationContext = mock(DeserializationContext.class);
when(jsonParser.getCodec()).thenReturn(objectMapper);

final String invalidValue = UUID.randomUUID().toString();
when(objectMapper.readTree(jsonParser)).thenReturn(new TextNode(invalidValue));

final IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () ->
objectUnderTest.deserialize(jsonParser, deserializationContext));

assertThat(exception, notNullValue());
final String expectedErrorMessage = "Invalid value \"" + invalidValue + "\". Valid options include";
assertThat(exception.getMessage(), Matchers.startsWith(expectedErrorMessage));

}

@Test
void create_contextual_returns_expected_enum_deserializer() {
final DeserializationContext context = mock(DeserializationContext.class);
final BeanProperty property = mock(BeanProperty.class);

final ObjectMapper mapper = new ObjectMapper();
final JavaType javaType = mapper.constructType(HandleFailedEventsOption.class);
when(property.getType()).thenReturn(javaType);

final EnumDeserializer objectUnderTest = new EnumDeserializer();
JsonDeserializer<?> result = objectUnderTest.createContextual(context, property);

assertThat(result, instanceOf(EnumDeserializer.class));
}

private enum TestEnum {
TEST_ONE("test_display_one"),
TEST_TWO("test_display_two"),
TEST_THREE("test_display_three");
private static final Map<String, TestEnum> NAMES_MAP = Arrays.stream(TestEnum.values())
.collect(Collectors.toMap(TestEnum::toString, Function.identity()));
private final String name;
TestEnum(final String name) {
this.name = name;
}

@JsonValue
public String toString() {
return this.name;
}
@JsonCreator
static TestEnum fromOptionValue(final String option) {
return NAMES_MAP.get(option);
}
}

private enum TestEnumWithoutJsonCreator {
TEST("test");
private static final Map<String, TestEnumWithoutJsonCreator> NAMES_MAP = Arrays.stream(TestEnumWithoutJsonCreator.values())
.collect(Collectors.toMap(TestEnumWithoutJsonCreator::toString, Function.identity()));
private final String name;
TestEnumWithoutJsonCreator(final String name) {
this.name = name;
}
public String toString() {
return this.name;
}

static TestEnumWithoutJsonCreator fromOptionValue(final String option) {
return NAMES_MAP.get(option);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.opensearch.dataprepper.model.types.ByteCount;
import org.opensearch.dataprepper.pipeline.parser.ByteCountDeserializer;
import org.opensearch.dataprepper.pipeline.parser.DataPrepperDurationDeserializer;
import org.opensearch.dataprepper.pipeline.parser.EnumDeserializer;
import org.opensearch.dataprepper.pipeline.parser.EventKeyDeserializer;
import org.springframework.context.annotation.Bean;

Expand All @@ -33,6 +34,7 @@ public class ObjectMapperConfiguration {
ObjectMapper extensionPluginConfigObjectMapper() {
final SimpleModule simpleModule = new SimpleModule();
simpleModule.addDeserializer(Duration.class, new DataPrepperDurationDeserializer());
simpleModule.addDeserializer(Enum.class, new EnumDeserializer());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will be applied to every Enum defined in the project?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes it will be

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any configuration file that has a variable that is an Enum

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice 👍

simpleModule.addDeserializer(ByteCount.class, new ByteCountDeserializer());

return new ObjectMapper()
Expand All @@ -47,6 +49,7 @@ ObjectMapper pluginConfigObjectMapper(
final SimpleModule simpleModule = new SimpleModule();
simpleModule.addDeserializer(Duration.class, new DataPrepperDurationDeserializer());
simpleModule.addDeserializer(ByteCount.class, new ByteCountDeserializer());
simpleModule.addDeserializer(Enum.class, new EnumDeserializer());
simpleModule.addDeserializer(EventKey.class, new EventKeyDeserializer(eventKeyFactory));
TRANSLATE_VALUE_SUPPORTED_JAVA_TYPES.stream().forEach(clazz -> simpleModule.addDeserializer(
clazz, new DataPrepperScalarTypeDeserializer<>(variableExpander, clazz)));
Expand Down
Loading
Loading