# Basic Operations
This example demonstrates the fundamental tensor operations in Tensor Frame. It covers tensor creation, basic arithmetic, shape manipulation, and data access patterns.
## Complete Example
```rust
use tensor_frame::{Tensor, Result, TensorOps};
use std::time::Instant;

fn main() -> Result<()> {
    println!("=== Tensor Frame Basic Operations ===\n");

    // 1. Tensor Creation
    tensor_creation_examples()?;

    // 2. Basic Arithmetic
    arithmetic_examples()?;

    // 3. Shape Manipulation
    shape_manipulation_examples()?;

    // 4. Data Access
    data_access_examples()?;

    // 5. Performance Comparison
    performance_comparison()?;

    // 6. Advanced Patterns
    advanced_patterns()?;

    // 7. Error Handling
    error_handling_examples()?;

    Ok(())
}

/// Demonstrates various ways to create tensors
fn tensor_creation_examples() -> Result<()> {
    println!("=== Tensor Creation ===");

    // Create tensor filled with zeros
    let zeros = Tensor::zeros(vec![2, 3])?;
    println!("Zeros tensor (2x3):\n{}\n", zeros);

    // Create tensor filled with ones
    let ones = Tensor::ones(vec![3, 2])?;
    println!("Ones tensor (3x2):\n{}\n", ones);

    // Create tensor from existing data
    let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let from_data = Tensor::from_vec(data, vec![2, 3])?;
    println!("From data (2x3):\n{}\n", from_data);

    // Check tensor properties
    println!("Tensor properties:");
    println!("  Shape: {:?}", from_data.shape().dims());
    println!("  Number of elements: {}", from_data.numel());
    println!("  Data type: {:?}", from_data.dtype());
    println!("  Backend: {:?}\n", from_data.backend_type());

    Ok(())
}

/// Demonstrates basic arithmetic operations
fn arithmetic_examples() -> Result<()> {
    println!("=== Arithmetic Operations ===");

    // Create test tensors
    let a = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2])?;
    let b = Tensor::from_vec(vec![5.0, 6.0, 7.0, 8.0], vec![2, 2])?;

    println!("Tensor A:\n{}\n", a);
    println!("Tensor B:\n{}\n", b);

    // Element-wise addition
    let sum = &a + &b; // Use references to avoid moving tensors
    println!("A + B:\n{}\n", sum);

    // Element-wise subtraction
    let diff = &a - &b;
    println!("A - B:\n{}\n", diff);

    // Element-wise multiplication
    let product = &a * &b;
    println!("A * B (element-wise):\n{}\n", product);

    // Element-wise division
    let quotient = &a / &b;
    println!("A / B:\n{}\n", quotient);

    // Chained operations
    let complex = ((&a * 2.0) + &b) / 3.0;
    println!("(A * 2 + B) / 3:\n{}\n", complex);

    Ok(())
}

/// Demonstrates shape manipulation operations
fn shape_manipulation_examples() -> Result<()> {
    println!("=== Shape Manipulation ===");

    // Create a tensor to manipulate
    let tensor = Tensor::from_vec(
        vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
        vec![2, 4],
    )?;
    println!("Original tensor (2x4):\n{}\n", tensor);

    // Reshape to different dimensions
    let reshaped = tensor.reshape(vec![4, 2])?;
    println!("Reshaped to (4x2):\n{}\n", reshaped);

    // Reshape to 1D
    let flattened = tensor.reshape(vec![8])?;
    println!("Flattened to (8,):\n{}\n", flattened);

    // Transpose (2D only)
    let matrix = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2])?;
    let transposed = matrix.transpose()?;
    println!("Original matrix:\n{}\n", matrix);
    println!("Transposed matrix:\n{}\n", transposed);

    // Squeeze and unsqueeze
    let with_ones = Tensor::ones(vec![1, 3, 1])?;
    println!("Tensor with size-1 dimensions (1x3x1):\n{}\n", with_ones);

    let squeezed = with_ones.squeeze(None)?;
    println!("Squeezed (removes all size-1 dims):\n{}\n", squeezed);

    let unsqueezed = squeezed.unsqueeze(0)?;
    println!("Unsqueezed at dimension 0:\n{}\n", unsqueezed);

    Ok(())
}

/// Demonstrates data access patterns
fn data_access_examples() -> Result<()> {
    println!("=== Data Access ===");

    let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2])?;
    println!("Tensor:\n{}\n", tensor);

    // Convert to Vec for external use
    let data = tensor.to_vec()?;
    println!("As Vec<f32>: {:?}\n", data);

    // Reduction operations
    let sum_all = tensor.sum(None)?;
    println!("Sum of all elements: {}\n", sum_all);

    let mean_all = tensor.mean(None)?;
    println!("Mean of all elements: {}\n", mean_all);

    // Axis-specific reductions
    let row_sums = tensor.sum(Some(1))?; // Sum along columns (axis 1)
    println!("Row sums (sum along axis 1): {}\n", row_sums);

    let col_sums = tensor.sum(Some(0))?; // Sum along rows (axis 0)
    println!("Column sums (sum along axis 0): {}\n", col_sums);

    Ok(())
}

/// Demonstrates performance characteristics
fn performance_comparison() -> Result<()> {
    println!("=== Performance Comparison ===");

    // Small tensor operations (CPU should be faster)
    let small_a = Tensor::ones(vec![100, 100])?;
    let small_b = Tensor::ones(vec![100, 100])?;

    let start = Instant::now();
    let _result = &small_a + &small_b;
    let small_time = start.elapsed();
    println!("Small tensor (100x100) addition: {:?}", small_time);

    // Large tensor operations (GPU might be faster if available)
    let large_a = Tensor::ones(vec![1000, 1000])?;
    let large_b = Tensor::ones(vec![1000, 1000])?;

    let start = Instant::now();
    let result = &large_a + &large_b;
    let large_time = start.elapsed();
    println!("Large tensor (1000x1000) addition: {:?}", large_time);

    // Show current backend
    println!("Current backend: {:?}", result.backend_type());

    // Demonstrate backend conversion (if other backends available)
    #[cfg(feature = "wgpu")]
    {
        println!("\n--- WGPU Backend Comparison ---");

        let start = Instant::now();
        let wgpu_a = large_a.to_backend(tensor_frame::BackendType::Wgpu)?;
        let wgpu_b = large_b.to_backend(tensor_frame::BackendType::Wgpu)?;
        let conversion_time = start.elapsed();

        let start = Instant::now();
        let wgpu_result = &wgpu_a + &wgpu_b;
        let _sync = wgpu_result.to_vec()?; // Force synchronization
        let wgpu_time = start.elapsed();

        println!("WGPU conversion time: {:?}", conversion_time);
        println!("WGPU computation time: {:?}", wgpu_time);
        println!("Total WGPU time: {:?}", conversion_time + wgpu_time);
    }

    Ok(())
}

/// Advanced patterns demonstration
fn advanced_patterns() -> Result<()> {
    println!("=== Advanced Patterns ===");

    // Broadcasting example
    let matrix = Tensor::ones(vec![3, 4])?; // Shape: [3, 4]
    let vector = Tensor::ones(vec![4])?;    // Shape: [4]
    let broadcasted = &matrix + &vector;    // Result: [3, 4]
    println!("Matrix (3x4):\n{}\n", matrix);
    println!("Vector (4,):\n{}\n", vector);
    println!("Matrix + Vector (broadcasted):\n{}\n", broadcasted);

    // Complex broadcasting
    let a = Tensor::ones(vec![2, 1, 3])?; // Shape: [2, 1, 3]
    let b = Tensor::ones(vec![1, 4, 1])?; // Shape: [1, 4, 1]
    let complex_broadcast = &a + &b;      // Result: [2, 4, 3]
    println!("Complex broadcasting:");
    println!("A shape: {:?}", a.shape().dims());
    println!("B shape: {:?}", b.shape().dims());
    println!("Result shape: {:?}", complex_broadcast.shape().dims());

    // Method chaining
    let result = Tensor::ones(vec![2, 3])?
        .reshape(vec![3, 2])?
        .transpose()?;
    println!("Method chaining result:\n{}\n", result);

    Ok(())
}

/// Error handling examples
fn error_handling_examples() -> Result<()> {
    println!("=== Error Handling ===");

    // Shape mismatch error
    let a = Tensor::ones(vec![2, 3])?;
    let b = Tensor::ones(vec![3, 2])?;
    match &a + &b {
        Ok(result) => println!("Addition succeeded: {}", result),
        Err(e) => println!("Expected error - shape mismatch: {}", e),
    }

    // Invalid reshape error
    let tensor = Tensor::ones(vec![2, 3])?; // 6 elements
    match tensor.reshape(vec![2, 2]) {      // 4 elements - invalid!
        Ok(result) => println!("Reshape succeeded: {}", result),
        Err(e) => println!("Expected error - invalid reshape: {}", e),
    }

    // Out-of-bounds dimension error
    match tensor.squeeze(Some(5)) { // Dimension 5 doesn't exist
        Ok(result) => println!("Squeeze succeeded: {}", result),
        Err(e) => println!("Expected error - invalid dimension: {}", e),
    }

    Ok(())
}
```
## Key Concepts Demonstrated
### 1. Tensor Creation
Three primary ways to create tensors:
- `Tensor::zeros(shape)` - creates a tensor filled with zeros
- `Tensor::ones(shape)` - creates a tensor filled with ones
- `Tensor::from_vec(data, shape)` - creates a tensor from existing data
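A minimal sketch combining all three constructors. It uses only the API shown in the complete example above; the surrounding function exists just so `?` has a `Result` to propagate into:

```rust
use tensor_frame::{Tensor, Result};

fn creation_demo() -> Result<()> {
    let zeros = Tensor::zeros(vec![2, 3])?; // 2x3, all elements 0.0
    let ones = Tensor::ones(vec![3, 2])?;   // 3x2, all elements 1.0
    let from_data = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2])?; // 2x2 from a Vec
    println!("{}\n{}\n{}", zeros, ones, from_data);
    Ok(())
}
```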
### 2. Reference vs. Owned Operations
```rust
// Moves tensors (can only use once)
let result = a + b;

// Uses references (can reuse tensors)
let result = &a + &b;
```
### 3. Shape Broadcasting
Tensor Frame automatically broadcasts compatible shapes:
```rust
let matrix = Tensor::ones(vec![3, 4])?; // [3, 4]
let vector = Tensor::ones(vec![4])?;    // [4] broadcasts to [1, 4]
let result = matrix + vector;           // Result: [3, 4]
```
### 4. Method Chaining
Operations can be chained for concise code:
```rust
let result = tensor
    .reshape(vec![4, 2])?
    .transpose()?
    .squeeze(None)?;
```
### 5. Error Handling
All operations return `Result<T>` for proper error handling:
```rust
match risky_operation() {
    Ok(tensor) => process_tensor(tensor),
    Err(TensorError::ShapeMismatch { expected, got }) => {
        eprintln!("Shape error: expected {:?}, got {:?}", expected, got);
    }
    Err(e) => eprintln!("Other error: {}", e),
}
```
## Performance Tips
- **Use references**: use `&a + &b` instead of `a + b` to avoid unnecessary clones
- **Batch operations**: combine operations when possible, e.g. `(a * 2.0) + b` rather than separate steps
- **Choose the right backend**: CPU for small tensors, GPU for large operations (see the sketch after this list)
- **Avoid frequent conversions**: stay on one backend when possible
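As a rough illustration of the backend tips, the sketch below moves a tensor to the GPU only when it is large enough to amortize the conversion overhead. It reuses `numel`, `to_backend`, and `BackendType::Wgpu` from the complete example above; the one-million-element threshold is an arbitrary placeholder to tune for your hardware:

```rust
use tensor_frame::{Tensor, Result};

/// Offload a tensor to the WGPU backend only when it is large enough
/// to amortize the conversion cost (the threshold is illustrative).
fn maybe_offload(t: Tensor) -> Result<Tensor> {
    #[cfg(feature = "wgpu")]
    {
        if t.numel() > 1_000_000 {
            return t.to_backend(tensor_frame::BackendType::Wgpu);
        }
    }
    Ok(t) // small tensors stay on the current (CPU) backend
}
```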
## Common Pitfalls
- **Shape mismatches**: ensure operand shapes are compatible (identical or broadcastable)
- **Invalid reshapes**: the new shape must contain the same total number of elements (see the check sketched after this list)
- **Backend overhead**: GPU operations carry launch and transfer overhead that dominates on small tensors
- **Memory usage**: large tensors consume significant memory
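The first two pitfalls can be caught before calling into the library. A minimal sketch, assuming `reshape` takes `&self` and returns `Err` on a mismatch as the error-handling example shows; `checked_reshape` is a hypothetical helper, not part of Tensor Frame:

```rust
use tensor_frame::{Tensor, Result};

/// Hypothetical helper: verify the element count before reshaping,
/// so the failure site is obvious at a glance.
fn checked_reshape(t: &Tensor, dims: Vec<usize>) -> Result<Tensor> {
    let target: usize = dims.iter().product();
    if t.numel() != target {
        eprintln!(
            "reshape {:?} -> {:?} changes the element count ({} vs {})",
            t.shape().dims(), dims, t.numel(), target
        );
    }
    t.reshape(dims) // still returns Err on a mismatch
}
```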
## Next Steps
After mastering basic operations, explore:
- Broadcasting Examples - Advanced broadcasting patterns
- Backend Selection - Optimizing backend usage
- Performance Guide - Advanced performance optimization