use super::{CloneIn, Gc, IntoObject, Object, Symbol, TagType};
use crate::{
    channel_manager::{self, ChannelId, RecvError, SendError},
    core::{
        env::sym,
        gc::{Block, Context, GcHeap, GcState, Trace},
    },
    derive_GcMoveable,
};
use std::{
    fmt,
    sync::Arc,
    time::{Duration, Instant},
};
21
22impl From<SendError> for Symbol<'static> {
24 fn from(err: SendError) -> Self {
25 match err {
26 SendError::Closed => sym::CHANNEL_CLOSED,
27 SendError::Full => sym::CHANNEL_FULL,
28 SendError::Timeout => sym::CHANNEL_TIMEOUT,
29 }
30 }
31}
32
33impl From<RecvError> for Symbol<'static> {
34 fn from(err: RecvError) -> Self {
35 match err {
36 RecvError::Closed => sym::CHANNEL_CLOSED,
37 RecvError::Empty => sym::CHANNEL_EMPTY,
38 RecvError::Timeout => sym::CHANNEL_TIMEOUT,
39 }
40 }
41}
42
/// Lisp-visible handle to the sending half of a channel.
///
/// The handle's state is a [`ChannelSenderInner`] stored inside a `GcHeap`.
pub(crate) struct ChannelSender(pub(in crate::core) GcHeap<ChannelSenderInner>);

derive_GcMoveable!(ChannelSender);

/// State behind a [`ChannelSender`]: which channel it writes to and the
/// manager through which every send/close operation is routed.
pub(in crate::core) struct ChannelSenderInner {
    /// Sender-side id returned by `ChannelManager::new_channel_pair`.
    pub(in crate::core) channel_id: ChannelId,
    /// Shared manager handling `send`, `try_send`, `close_sender` and cleanup.
    pub(in crate::core) manager: Arc<channel_manager::ChannelManager>,
}
52
53impl ChannelSender {
54 pub(in crate::core) fn new(
55 channel_id: ChannelId,
56 manager: Arc<channel_manager::ChannelManager>,
57 constant: bool,
58 ) -> Self {
59 manager.increment_sender(channel_id);
60 Self(GcHeap::new(ChannelSenderInner { channel_id, manager }, constant))
61 }
62
63 pub(crate) fn send<'ob>(&self, obj: Object<'ob>) -> Result<(), SendError> {
66 self.0.manager.send(self.0.channel_id, obj)
67 }
68
69 pub(crate) fn try_send<'ob>(&self, obj: Object<'ob>) -> Result<(), SendError> {
73 self.0.manager.try_send(self.0.channel_id, obj)
74 }
75
76 pub(crate) fn send_timeout<'ob>(
80 &self,
81 obj: Object<'ob>,
82 timeout: Duration,
83 ) -> Result<(), SendError> {
84 let start = Instant::now();
85
86 loop {
87 match self.try_send(obj) {
88 Ok(()) => return Ok(()),
89 Err(SendError::Full) => {
90 if start.elapsed() >= timeout {
91 return Err(SendError::Timeout);
92 }
93 std::thread::sleep(Duration::from_millis(1));
95 }
96 Err(e) => return Err(e),
97 }
98 }
99 }
100
101 pub(crate) fn close(&self) {
105 self.0.manager.close_sender(self.0.channel_id);
106 }
107}
108
/// Lisp-visible handle to the receiving half of a channel.
///
/// The handle's state is a [`ChannelReceiverInner`] stored inside a `GcHeap`.
pub(crate) struct ChannelReceiver(pub(in crate::core) GcHeap<ChannelReceiverInner>);

derive_GcMoveable!(ChannelReceiver);

/// State behind a [`ChannelReceiver`]: which channel it reads from and the
/// manager through which every receive/close operation is routed.
pub(in crate::core) struct ChannelReceiverInner {
    /// Receiver-side id returned by `ChannelManager::new_channel_pair`.
    pub(in crate::core) channel_id: ChannelId,
    /// Shared manager handling `recv`, `try_recv`, `close_receiver` and cleanup.
    pub(in crate::core) manager: Arc<channel_manager::ChannelManager>,
}
118
119impl ChannelReceiver {
120 pub(in crate::core) fn new(
121 channel_id: ChannelId,
122 manager: Arc<channel_manager::ChannelManager>,
123 constant: bool,
124 ) -> Self {
125 manager.increment_receiver(channel_id);
126 Self(GcHeap::new(ChannelReceiverInner { channel_id, manager }, constant))
127 }
128
129 pub(crate) fn recv<'ob>(&self, cx: &'ob Context) -> Result<Object<'ob>, RecvError> {
132 self.0.manager.recv(self.0.channel_id, &cx.block)
133 }
134
135 pub(crate) fn try_recv<'ob>(&self, cx: &'ob Context) -> Result<Object<'ob>, RecvError> {
139 self.0.manager.try_recv(self.0.channel_id, &cx.block)
140 }
141
142 pub(crate) fn recv_timeout<'ob>(
146 &self,
147 cx: &'ob Context,
148 timeout: Duration,
149 ) -> Result<Object<'ob>, RecvError> {
150 let start = Instant::now();
151
152 loop {
153 match self.try_recv(cx) {
154 Ok(obj) => return Ok(obj),
155 Err(RecvError::Empty) => {
156 if start.elapsed() >= timeout {
157 return Err(RecvError::Timeout);
158 }
159 std::thread::sleep(Duration::from_millis(1));
161 }
162 Err(e) => return Err(e),
163 }
164 }
165 }
166
167 pub(crate) fn close(&self) {
174 self.0.manager.close_receiver(self.0.channel_id);
175 }
176}
177
178pub(crate) fn make_channel_pair(capacity: usize) -> (ChannelSender, ChannelReceiver) {
180 let manager = channel_manager::get_manager();
181 let (sender_id, receiver_id) = manager.new_channel_pair(capacity);
182
183 let sender = ChannelSender::new(sender_id, manager.clone(), false);
184 let receiver = ChannelReceiver::new(receiver_id, manager, false);
185 (sender, receiver)
186}
187
188impl Drop for ChannelSenderInner {
190 fn drop(&mut self) {
191 self.manager.close_sender(self.channel_id);
192 self.manager.cleanup_channel(self.channel_id);
193 }
194}
195
196impl Drop for ChannelReceiverInner {
197 fn drop(&mut self) {
198 self.manager.close_receiver(self.channel_id);
199 self.manager.cleanup_channel(self.channel_id);
200 }
201}
202
203impl Trace for ChannelSenderInner {
205 fn trace(&self, _state: &mut GcState) {
206 }
208}
209
210impl Trace for ChannelReceiverInner {
211 fn trace(&self, _state: &mut GcState) {
212 }
214}
215
216impl<'new> CloneIn<'new, &'new Self> for ChannelSender {
218 fn clone_in<const C: bool>(&self, _bk: &'new Block<C>) -> Gc<&'new Self> {
219 let new_sender = ChannelSender::new(self.0.channel_id, self.0.manager.clone(), false);
220
221 unsafe { std::mem::transmute::<Gc<&Self>, Gc<&'new Self>>((&new_sender).tag()) }
223 }
224}
225
226impl<'new> CloneIn<'new, &'new Self> for ChannelReceiver {
227 fn clone_in<const C: bool>(&self, _bk: &'new Block<C>) -> Gc<&'new Self> {
228 let new_receiver = ChannelReceiver::new(self.0.channel_id, self.0.manager.clone(), false);
229
230 unsafe { std::mem::transmute::<Gc<&Self>, Gc<&'new Self>>((&new_receiver).tag()) }
232 }
233}
234
235impl fmt::Debug for ChannelSender {
239 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
240 write!(f, "#<channel-sender:{}>", self.0.channel_id)
241 }
242}
243
244impl fmt::Display for ChannelSender {
245 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
246 write!(f, "#<channel-sender>")
247 }
248}
249
250impl fmt::Debug for ChannelReceiver {
251 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
252 write!(f, "#<channel-receiver:{}>", self.0.channel_id)
253 }
254}
255
256impl fmt::Display for ChannelReceiver {
257 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
258 write!(f, "#<channel-receiver>")
259 }
260}
261
262impl PartialEq for ChannelSender {
264 fn eq(&self, other: &Self) -> bool {
265 std::ptr::eq(&*self.0, &*other.0)
266 }
267}
268
269impl Eq for ChannelSender {}
270
271impl PartialEq for ChannelReceiver {
272 fn eq(&self, other: &Self) -> bool {
273 std::ptr::eq(&*self.0, &*other.0)
274 }
275}
276
277impl Eq for ChannelReceiver {}
278
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{
        cons::Cons,
        gc::{Context, RootSet},
        object::{IntoObject, ObjectType},
    };

    /// A single integer survives a send/recv round trip.
    #[test]
    fn test_basic_send_recv() {
        let roots = RootSet::default();
        let cx = Context::new(&roots);

        let (sender, receiver) = make_channel_pair(1);

        sender.send(cx.add(42)).unwrap();

        let result = receiver.recv(&cx).unwrap();
        if let ObjectType::Int(n) = result.untag() {
            assert_eq!(n, 42);
        } else {
            panic!("Expected integer");
        }
    }

    /// Received heap objects are copies: same contents, different allocation.
    #[test]
    fn test_double_clone_verification() {
        let roots = RootSet::default();
        let cx = Context::new(&roots);

        let (sender, receiver) = make_channel_pair(1);

        let original = cx.add("test string");

        sender.send(original).unwrap();

        let received = receiver.recv(&cx).unwrap();

        if let ObjectType::String(s1) = original.untag() {
            if let ObjectType::String(s2) = received.untag() {
                assert_eq!(s1.as_ref(), s2.as_ref());

                assert_ne!(
                    s1.as_ptr() as usize,
                    s2.as_ptr() as usize,
                    "Strings should be different allocations"
                );
            } else {
                panic!("Expected string");
            }
        } else {
            panic!("Expected string");
        }
    }

    /// Dropping the only sender makes a subsequent recv report `Closed`.
    #[test]
    fn test_channel_closed_on_sender_drop() {
        let roots = RootSet::default();
        let cx = Context::new(&roots);

        let (sender, receiver) = make_channel_pair(1);

        drop(sender);

        let result = receiver.recv(&cx);
        assert!(matches!(result, Err(RecvError::Closed)));
    }

    /// A full channel rejects `try_send`; draining it frees the slot again,
    /// and values come out in the order they were sent.
    #[test]
    fn test_channel_full() {
        let roots = RootSet::default();
        let cx = Context::new(&roots);

        let (sender, receiver) = make_channel_pair(1);

        sender.send(cx.add(1)).unwrap();

        // Capacity is 1, so a second non-blocking send must be rejected.
        let result = sender.try_send(cx.add(2));
        assert!(matches!(result, Err(SendError::Full)));

        let first = receiver.recv(&cx).unwrap();
        assert!(matches!(first.untag(), ObjectType::Int(1)));

        sender.try_send(cx.add(2)).unwrap();

        let second = receiver.recv(&cx).unwrap();
        assert!(matches!(second.untag(), ObjectType::Int(2)));
    }

    /// `try_recv` on an empty but open channel reports `Empty`.
    #[test]
    fn test_channel_empty() {
        let roots = RootSet::default();
        let cx = Context::new(&roots);

        let (_sender, receiver) = make_channel_pair(1);

        let result = receiver.try_recv(&cx);
        assert!(matches!(result, Err(RecvError::Empty)));
    }

    /// Strings, cons cells and vectors all arrive intact and in order.
    #[test]
    fn test_complex_objects() {
        let roots = RootSet::default();
        let cx = Context::new(&roots);

        let (sender, receiver) = make_channel_pair(5);

        let str_obj = cx.add("hello");
        sender.send(str_obj).unwrap();

        let list = Cons::new(cx.add(1), cx.add(2), &cx).into_obj(&cx.block);
        sender.send(list.into()).unwrap();

        let vec_obj = cx.add(vec![cx.add(10), cx.add(20), cx.add(30)]);
        sender.send(vec_obj).unwrap();

        let recv_str = receiver.recv(&cx).unwrap();
        if let ObjectType::String(s) = recv_str.untag() {
            assert_eq!(s.as_ref(), "hello");
        } else {
            panic!("Expected string");
        }

        let recv_list = receiver.recv(&cx).unwrap();
        if let ObjectType::Cons(cons) = recv_list.untag() {
            let car_obj = cons.car();
            if let ObjectType::Int(car) = car_obj.untag() {
                assert_eq!(car, 1);
            } else {
                panic!("Expected int in car");
            }
        } else {
            panic!("Expected cons");
        }

        let recv_vec = receiver.recv(&cx).unwrap();
        if let ObjectType::Vec(v) = recv_vec.untag() {
            assert_eq!(v.len(), 3);
            // A non-int first element must fail the test, not silently pass.
            if let ObjectType::Int(n) = v[0].get().untag() {
                assert_eq!(n, 10);
            } else {
                panic!("Expected int in vector");
            }
        } else {
            panic!("Expected vector");
        }
    }

    /// `recv_timeout` on a channel nobody sends to reports `Timeout`.
    #[test]
    fn test_timeout() {
        let roots = RootSet::default();
        let cx = Context::new(&roots);

        let (_sender, receiver) = make_channel_pair(1);

        let result = receiver.recv_timeout(&cx, Duration::from_millis(10));
        assert!(matches!(result, Err(RecvError::Timeout)));
    }

    /// Messages are delivered in FIFO order when the buffer holds them all.
    #[test]
    fn test_multiple_messages() {
        let roots = RootSet::default();
        let cx = Context::new(&roots);

        let (sender, receiver) = make_channel_pair(10);

        for i in 0..10 {
            sender.send(cx.add(i)).unwrap();
        }

        for i in 0..10 {
            let result = receiver.recv(&cx).unwrap();
            if let ObjectType::Int(n) = result.untag() {
                assert_eq!(n, i);
            } else {
                panic!("Expected integer");
            }
        }
    }

    /// Compile-time check that a sender can be moved to another thread.
    #[test]
    fn test_channel_sender_is_send() {
        fn assert_send<T: Send>() {}
        assert_send::<ChannelSender>();
    }

    /// Compile-time check that a receiver can be moved to another thread.
    #[test]
    fn test_channel_receiver_is_send() {
        fn assert_send<T: Send>() {}
        assert_send::<ChannelReceiver>();
    }

    /// Values sent from a spawned thread (with its own `Context`) reach the main thread.
    #[test]
    fn test_cross_thread_channel() {
        use std::thread;

        let roots = RootSet::default();
        let cx = Context::new(&roots);

        let (sender, receiver) = make_channel_pair(5);

        let handle = thread::spawn(move || {
            let roots2 = RootSet::default();
            let cx2 = Context::new(&roots2);
            for i in 0..5 {
                sender.send(cx2.add(i * 10)).unwrap();
            }
            // NOTE(review): `sender` is dropped here, possibly before the main
            // thread drains the buffer; this relies on buffered messages still
            // being receivable after the sender closes.
        });

        for i in 0..5 {
            let result = receiver.recv(&cx).unwrap();
            if let ObjectType::Int(n) = result.untag() {
                assert_eq!(n, i * 10);
            } else {
                panic!("Expected integer");
            }
        }

        handle.join().unwrap();
    }

    /// Four independent channels, each with its own producer and consumer thread.
    #[test]
    fn test_concurrent_channel_stress() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::thread;

        let roots = RootSet::default();
        let _cx = Context::new(&roots);

        let send_count = Arc::new(AtomicUsize::new(0));
        let recv_count = Arc::new(AtomicUsize::new(0));

        let mut handles = vec![];
        for idx in 0..4 {
            let (sender, receiver) = make_channel_pair(10);
            let send_counter = Arc::clone(&send_count);
            let recv_counter = Arc::clone(&recv_count);

            handles.push(thread::spawn(move || {
                let roots = RootSet::default();
                let cx = Context::new(&roots);
                for i in 0..20 {
                    let value = (idx * 100 + i) as i64;
                    sender.send(cx.add(value)).unwrap();
                    send_counter.fetch_add(1, Ordering::SeqCst);
                }
            }));

            handles.push(thread::spawn(move || {
                let roots = RootSet::default();
                let cx = Context::new(&roots);
                for _ in 0..20 {
                    let result = receiver.recv(&cx).unwrap();
                    assert!(matches!(result.untag(), ObjectType::Int(_)));
                    recv_counter.fetch_add(1, Ordering::SeqCst);
                }
            }));
        }

        for handle in handles {
            handle.join().unwrap();
        }

        assert_eq!(send_count.load(Ordering::SeqCst), 80);
        assert_eq!(recv_count.load(Ordering::SeqCst), 80);
    }
}