diff --git a/spring-core/src/main/java/org/springframework/util/LinkedCaseInsensitiveMap.java b/spring-core/src/main/java/org/springframework/util/LinkedCaseInsensitiveMap.java index ed55e67805..bf12dabf92 100644 --- a/spring-core/src/main/java/org/springframework/util/LinkedCaseInsensitiveMap.java +++ b/spring-core/src/main/java/org/springframework/util/LinkedCaseInsensitiveMap.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,12 +17,17 @@ package org.springframework.util; import java.io.Serializable; +import java.util.AbstractCollection; +import java.util.AbstractSet; import java.util.Collection; import java.util.HashMap; +import java.util.Iterator; import java.util.LinkedHashMap; import java.util.Locale; import java.util.Map; import java.util.Set; +import java.util.Spliterator; +import java.util.function.Consumer; import java.util.function.Function; import org.springframework.lang.Nullable; @@ -37,6 +42,7 @@ import org.springframework.lang.Nullable; *

Does not support {@code null} keys. * * @author Juergen Hoeller + * @author Phillip Webb * @since 3.0 * @param the value type */ @@ -49,6 +55,12 @@ public class LinkedCaseInsensitiveMap implements Map, Serializable private final Locale locale; + private transient Set keySet; + + private transient Collection values; + + private transient Set> entrySet; + /** * Create a new LinkedCaseInsensitiveMap that stores case-insensitive keys @@ -98,7 +110,7 @@ public class LinkedCaseInsensitiveMap implements Map, Serializable protected boolean removeEldestEntry(Map.Entry eldest) { boolean doRemove = LinkedCaseInsensitiveMap.this.removeEldestEntry(eldest); if (doRemove) { - caseInsensitiveKeys.remove(convertKey(eldest.getKey())); + removeCaseInsensitiveKey(eldest.getKey()); } return doRemove; } @@ -208,7 +220,7 @@ public class LinkedCaseInsensitiveMap implements Map, Serializable @Nullable public V remove(Object key) { if (key instanceof String) { - String caseInsensitiveKey = this.caseInsensitiveKeys.remove(convertKey((String) key)); + String caseInsensitiveKey = removeCaseInsensitiveKey((String) key); if (caseInsensitiveKey != null) { return this.targetMap.remove(caseInsensitiveKey); } @@ -224,17 +236,32 @@ public class LinkedCaseInsensitiveMap implements Map, Serializable @Override public Set keySet() { - return this.targetMap.keySet(); + Set keySet = this.keySet; + if (keySet == null) { + keySet = new KeySet(this.targetMap.keySet()); + this.keySet = keySet; + } + return keySet; } @Override public Collection values() { - return this.targetMap.values(); + Collection values = this.values; + if (values == null) { + values = new Values(this.targetMap.values()); + this.values = values; + } + return values; } @Override public Set> entrySet() { - return this.targetMap.entrySet(); + Set> entrySet = this.entrySet; + if (entrySet == null) { + entrySet = new EntrySet(this.targetMap.entrySet()); + this.entrySet = entrySet; + } + return entrySet; } @Override @@ -293,4 +320,216 @@ public class LinkedCaseInsensitiveMap implements Map, Serializable return false; } + private String removeCaseInsensitiveKey(String key) { + return this.caseInsensitiveKeys.remove(convertKey(key)); + } + + + private class KeySet extends AbstractSet { + + private final Set delegate; + + + KeySet(Set delegate) { + this.delegate = delegate; + } + + + @Override + public int size() { + return this.delegate.size(); + } + + @Override + public boolean contains(Object o) { + return this.delegate.contains(o); + } + + @Override + public Iterator iterator() { + return new KeySetIterator(); + } + + @Override + public boolean remove(Object o) { + return LinkedCaseInsensitiveMap.this.remove(o) != null; + } + + @Override + public void clear() { + LinkedCaseInsensitiveMap.this.clear(); + } + + @Override + public Spliterator spliterator() { + return this.delegate.spliterator(); + } + + @Override + public void forEach(Consumer action) { + this.delegate.forEach(action); + } + + } + + + private class Values extends AbstractCollection { + + private final Collection delegate; + + + Values(Collection delegate) { + this.delegate = delegate; + } + + + @Override + public int size() { + return this.delegate.size(); + } + + @Override + public boolean contains(Object o) { + return this.delegate.contains(o); + } + + @Override + public Iterator iterator() { + return new ValuesIterator(); + } + + @Override + public void clear() { + LinkedCaseInsensitiveMap.this.clear(); + } + + @Override + public Spliterator spliterator() { + return this.delegate.spliterator(); + } + + @Override + public void forEach(Consumer action) { + this.delegate.forEach(action); + } + + } + + + private class EntrySet extends AbstractSet> { + + private final Set> delegate; + + + public EntrySet(Set> delegate) { + this.delegate = delegate; + } + + + @Override + public int size() { + return this.delegate.size(); + } + + @Override + public boolean contains(Object o) { + return this.delegate.contains(o); + } + + @Override + public Iterator> iterator() { + return new EntrySetIterator(); + } + + + @Override + @SuppressWarnings("unchecked") + public boolean remove(Object o) { + if (this.delegate.remove(o)) { + removeCaseInsensitiveKey(((Map.Entry) o).getKey()); + return true; + } + return false; + } + + + @Override + public void clear() { + this.delegate.clear(); + caseInsensitiveKeys.clear(); + } + + @Override + public Spliterator> spliterator() { + return this.delegate.spliterator(); + } + + @Override + public void forEach(Consumer> action) { + this.delegate.forEach(action); + } + + } + + + private class EntryIterator { + + private final Iterator> delegate; + + private Entry last; + + public EntryIterator() { + this.delegate = targetMap.entrySet().iterator(); + } + + public Entry nextEntry() { + Entry entry = this.delegate.next(); + this.last = entry; + return entry; + } + + public boolean hasNext() { + return this.delegate.hasNext(); + } + + public void remove() { + this.delegate.remove(); + if(this.last != null) { + removeCaseInsensitiveKey(this.last.getKey()); + this.last = null; + } + } + + } + + + private class KeySetIterator extends EntryIterator implements Iterator { + + @Override + public String next() { + return nextEntry().getKey(); + } + + } + + + private class ValuesIterator extends EntryIterator implements Iterator { + + @Override + public V next() { + return nextEntry().getValue(); + } + + } + + + private class EntrySetIterator extends EntryIterator implements Iterator> { + + @Override + public Entry next() { + return nextEntry(); + } + + } + } diff --git a/spring-core/src/test/java/org/springframework/util/LinkedCaseInsensitiveMapTests.java b/spring-core/src/test/java/org/springframework/util/LinkedCaseInsensitiveMapTests.java index f698f764c6..c93324bea9 100644 --- a/spring-core/src/test/java/org/springframework/util/LinkedCaseInsensitiveMapTests.java +++ b/spring-core/src/test/java/org/springframework/util/LinkedCaseInsensitiveMapTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,12 +16,17 @@ package org.springframework.util; +import java.util.Iterator; + import org.junit.Test; import static org.junit.Assert.*; /** + * Tests for {@link LinkedCaseInsensitiveMap}. + * * @author Juergen Hoeller + * @author Phillip Webb */ public class LinkedCaseInsensitiveMapTests { @@ -127,4 +132,89 @@ public class LinkedCaseInsensitiveMapTests { assertEquals("value2", copy.get("Key")); } + + @Test + public void clearFromKeySet() { + map.put("key", "value"); + map.keySet().clear(); + map.computeIfAbsent("key", k -> "newvalue"); + assertEquals("newvalue", map.get("key")); + } + + @Test + public void removeFromKeySet() { + map.put("key", "value"); + map.keySet().remove("key"); + map.computeIfAbsent("key", k -> "newvalue"); + assertEquals("newvalue", map.get("key")); + } + + @Test + public void removeFromKeySetViaIterator() { + map.put("key", "value"); + nextAndRemove(map.keySet().iterator()); + assertEquals(0, map.size()); + map.computeIfAbsent("key", k -> "newvalue"); + assertEquals("newvalue", map.get("key")); + } + + @Test + public void clearFromValues() { + map.put("key", "value"); + map.values().clear(); + assertEquals(0, map.size()); + map.computeIfAbsent("key", k -> "newvalue"); + assertEquals("newvalue", map.get("key")); + } + + @Test + public void removeFromValues() { + map.put("key", "value"); + map.values().remove("value"); + assertEquals(0, map.size()); + map.computeIfAbsent("key", k -> "newvalue"); + assertEquals("newvalue", map.get("key")); + } + + @Test + public void removeFromValuesViaIterator() { + map.put("key", "value"); + nextAndRemove(map.values().iterator()); + assertEquals(0, map.size()); + map.computeIfAbsent("key", k -> "newvalue"); + assertEquals("newvalue", map.get("key")); + } + + @Test + public void clearFromEntrySet() { + map.put("key", "value"); + map.entrySet().clear(); + assertEquals(0, map.size()); + map.computeIfAbsent("key", k -> "newvalue"); + assertEquals("newvalue", map.get("key")); + } + + @Test + public void removeFromEntrySet() { + map.put("key", "value"); + map.entrySet().remove(map.entrySet().iterator().next()); + assertEquals(0, map.size()); + map.computeIfAbsent("key", k -> "newvalue"); + assertEquals("newvalue", map.get("key")); + } + + @Test + public void removeFromEntrySetViaIterator() { + map.put("key", "value"); + nextAndRemove(map.entrySet().iterator()); + assertEquals(0, map.size()); + map.computeIfAbsent("key", k -> "newvalue"); + assertEquals("newvalue", map.get("key")); + } + + private void nextAndRemove(Iterator iterator) { + iterator.next(); + iterator.remove(); + } + } diff --git a/spring-web/src/test/java/org/springframework/http/HttpHeadersTests.java b/spring-web/src/test/java/org/springframework/http/HttpHeadersTests.java index b0717250b8..d72e698112 100644 --- a/spring-web/src/test/java/org/springframework/http/HttpHeadersTests.java +++ b/spring-web/src/test/java/org/springframework/http/HttpHeadersTests.java @@ -562,7 +562,6 @@ public class HttpHeadersTests { } @Test - @Ignore("Disabled until gh-22821 is resolved") public void removalFromKeySetRemovesEntryFromUnderlyingMap() { String headerName = "MyHeader"; String headerValue = "value"; @@ -573,11 +572,10 @@ public class HttpHeadersTests { headers.keySet().removeIf(key -> key.equals(headerName)); assertTrue(headers.isEmpty()); headers.add(headerName, headerValue); - assertEquals(headerValue, headers.get(headerName)); + assertEquals(headerValue, headers.get(headerName).get(0)); } @Test - @Ignore("Disabled until gh-22821 is resolved") public void removalFromEntrySetRemovesEntryFromUnderlyingMap() { String headerName = "MyHeader"; String headerValue = "value"; @@ -588,7 +586,7 @@ public class HttpHeadersTests { headers.entrySet().removeIf(entry -> entry.getKey().equals(headerName)); assertTrue(headers.isEmpty()); headers.add(headerName, headerValue); - assertEquals(headerValue, headers.get(headerName)); + assertEquals(headerValue, headers.get(headerName).get(0)); } }