Browse Source

Fix race condition in transition from UNSUBSCRIBED->COMPLETED

- Ensure completion signal (normal/exception) will be delivered to
the subscriber when transition from UNSUBSCRIBED->COMPLETED

- According to the specification "Publisher.subscribe MUST call onSubscribe
on the provided Subscriber prior to any other signals to that Subscriber" so
ensure onComplete/onError signals will be called AFTER onSubscribe signal.

Issue: SPR-16207
pull/1598/merge
Violeta Georgieva 7 years ago committed by Rossen Stoyanchev
parent
commit
b814875211
  1. 46
      spring-web/src/main/java/org/springframework/http/server/reactive/AbstractListenerReadPublisher.java
  2. 40
      spring-web/src/main/java/org/springframework/http/server/reactive/WriteResultPublisher.java

46
spring-web/src/main/java/org/springframework/http/server/reactive/AbstractListenerReadPublisher.java

@ -52,6 +52,11 @@ public abstract class AbstractListenerReadPublisher<T> implements Publisher<T> {
private volatile long demand; private volatile long demand;
private volatile boolean publisherCompleted;
@Nullable
private volatile Throwable publisherError;
@SuppressWarnings("rawtypes") @SuppressWarnings("rawtypes")
private static final AtomicLongFieldUpdater<AbstractListenerReadPublisher> DEMAND_FIELD_UPDATER = private static final AtomicLongFieldUpdater<AbstractListenerReadPublisher> DEMAND_FIELD_UPDATER =
AtomicLongFieldUpdater.newUpdater(AbstractListenerReadPublisher.class, "demand"); AtomicLongFieldUpdater.newUpdater(AbstractListenerReadPublisher.class, "demand");
@ -208,15 +213,54 @@ public abstract class AbstractListenerReadPublisher<T> implements Publisher<T> {
<T> void subscribe(AbstractListenerReadPublisher<T> publisher, Subscriber<? super T> subscriber) { <T> void subscribe(AbstractListenerReadPublisher<T> publisher, Subscriber<? super T> subscriber) {
Assert.notNull(publisher, "Publisher must not be null"); Assert.notNull(publisher, "Publisher must not be null");
Assert.notNull(subscriber, "Subscriber must not be null"); Assert.notNull(subscriber, "Subscriber must not be null");
if (publisher.changeState(this, NO_DEMAND)) { if (publisher.changeState(this, SUBSCRIBING)) {
Subscription subscription = new ReadSubscription(publisher); Subscription subscription = new ReadSubscription(publisher);
publisher.subscriber = subscriber; publisher.subscriber = subscriber;
subscriber.onSubscribe(subscription); subscriber.onSubscribe(subscription);
publisher.changeState(SUBSCRIBING, NO_DEMAND);
if (publisher.publisherCompleted) {
publisher.onAllDataRead();
}
Throwable publisherError = publisher.publisherError;
if (publisherError != null) {
publisher.onError(publisherError);
}
} }
else { else {
throw new IllegalStateException(toString()); throw new IllegalStateException(toString());
} }
} }
@Override
<T> void onAllDataRead(AbstractListenerReadPublisher<T> publisher) {
publisher.publisherCompleted = true;
}
@Override
<T> void onError(AbstractListenerReadPublisher<T> publisher, Throwable t) {
publisher.publisherError = t;
}
},
SUBSCRIBING {
<T> void request(AbstractListenerReadPublisher<T> publisher, long n) {
if (Operators.validate(n)) {
Operators.addCap(DEMAND_FIELD_UPDATER, publisher, n);
if (publisher.changeState(this, DEMAND)) {
publisher.checkOnDataAvailable();
}
}
}
@Override
<T> void onAllDataRead(AbstractListenerReadPublisher<T> publisher) {
publisher.publisherCompleted = true;
}
@Override
<T> void onError(AbstractListenerReadPublisher<T> publisher, Throwable t) {
publisher.publisherError = t;
}
}, },
/** /**

40
spring-web/src/main/java/org/springframework/http/server/reactive/WriteResultPublisher.java

@ -119,18 +119,17 @@ class WriteResultPublisher implements Publisher<Void> {
@Override @Override
void subscribe(WriteResultPublisher publisher, Subscriber<? super Void> subscriber) { void subscribe(WriteResultPublisher publisher, Subscriber<? super Void> subscriber) {
Assert.notNull(subscriber, "Subscriber must not be null"); Assert.notNull(subscriber, "Subscriber must not be null");
publisher.subscriber = subscriber; if (publisher.changeState(this, SUBSCRIBING)) {
if (publisher.changeState(this, SUBSCRIBED)) {
Subscription subscription = new ResponseBodyWriteResultSubscription(publisher); Subscription subscription = new ResponseBodyWriteResultSubscription(publisher);
publisher.subscriber = subscriber;
subscriber.onSubscribe(subscription); subscriber.onSubscribe(subscription);
publisher.changeState(SUBSCRIBING, SUBSCRIBED);
if (publisher.publisherCompleted) { if (publisher.publisherCompleted) {
publisher.publishComplete(); publisher.publishComplete();
} }
else { Throwable publisherError = publisher.publisherError;
Throwable publisherError = publisher.publisherError; if (publisherError != null) {
if (publisherError != null) { publisher.publishError(publisherError);
publisher.publishError(publisherError);
}
} }
} }
else { else {
@ -147,6 +146,21 @@ class WriteResultPublisher implements Publisher<Void> {
} }
}, },
SUBSCRIBING {
@Override
void request(WriteResultPublisher publisher, long n) {
Operators.validate(n);
}
@Override
void publishComplete(WriteResultPublisher publisher) {
publisher.publisherCompleted = true;
}
@Override
void publishError(WriteResultPublisher publisher, Throwable t) {
publisher.publisherError = t;
}
},
SUBSCRIBED { SUBSCRIBED {
@Override @Override
void request(WriteResultPublisher publisher, long n) { void request(WriteResultPublisher publisher, long n) {
@ -183,14 +197,6 @@ class WriteResultPublisher implements Publisher<Void> {
void cancel(WriteResultPublisher publisher) { void cancel(WriteResultPublisher publisher) {
// ignore // ignore
} }
@Override
void publishComplete(WriteResultPublisher publisher) {
// ignore
}
@Override
void publishError(WriteResultPublisher publisher, Throwable t) {
// ignore
}
}; };
void subscribe(WriteResultPublisher publisher, Subscriber<? super Void> subscriber) { void subscribe(WriteResultPublisher publisher, Subscriber<? super Void> subscriber) {
@ -208,11 +214,11 @@ class WriteResultPublisher implements Publisher<Void> {
} }
void publishComplete(WriteResultPublisher publisher) { void publishComplete(WriteResultPublisher publisher) {
throw new IllegalStateException(toString()); // ignore
} }
void publishError(WriteResultPublisher publisher, Throwable t) { void publishError(WriteResultPublisher publisher, Throwable t) {
throw new IllegalStateException(toString()); // ignore
} }
} }

Loading…
Cancel
Save